commonware_cryptography/bls12381/primitives/
sharing.rs1use crate::bls12381::primitives::{group::Scalar, variant::Variant, Error};
2#[cfg(not(feature = "std"))]
3use alloc::sync::Arc;
4#[cfg(not(feature = "std"))]
5use alloc::vec::Vec;
6use cfg_if::cfg_if;
7use commonware_codec::{EncodeSize, FixedSize, RangeCfg, Read, ReadExt, Write};
8use commonware_macros::stability;
9#[stability(ALPHA)]
10use commonware_math::algebra::{FieldNTT, Ring};
11use commonware_math::poly::{Interpolator, Poly};
12use commonware_parallel::Sequential;
13#[stability(ALPHA)]
14use commonware_utils::{ordered::BiMap, TryFromIterator};
15use commonware_utils::{ordered::Set, Faults, Participant, NZU32};
16#[cfg(feature = "std")]
17use core::iter;
18use core::num::NonZeroU32;
19#[cfg(feature = "std")]
20use std::sync::{Arc, OnceLock};
21#[cfg(feature = "std")]
22use std::vec::Vec;
23
/// How participant indices are mapped to the field elements (evaluation points)
/// at which the shared polynomial is evaluated.
///
/// The `u8` discriminant of each variant is its wire tag: `Write` emits
/// `*self as u8` and `Read` maps the tag back, so existing values must never be
/// renumbered.
#[derive(Copy, Clone, Default, PartialEq, Eq, Debug)]
#[repr(u8)]
pub enum Mode {
    /// Participant `i` is evaluated at the scalar `i + 1`, so no participant is
    /// ever assigned the point zero (where the constant term, i.e. the public
    /// value returned by `Sharing::public`, sits). This is the default mode.
    #[default]
    NonZeroCounter = 0,

    /// Participant `i` is evaluated at `w^i`, where `w` is
    /// `Scalar::root_of_unity(k)` and `2^k` is the participant count rounded up
    /// to a power of two.
    ///
    /// Compiled out when any stability level above ALPHA is selected.
    #[cfg(not(any(
        commonware_stability_BETA,
        commonware_stability_GAMMA,
        commonware_stability_DELTA,
        commonware_stability_EPSILON,
        commonware_stability_RESERVED
    )))]
    RootsOfUnity = 1,
}
46
47impl Mode {
48 pub(crate) fn scalar(self, total: NonZeroU32, i: Participant) -> Option<Scalar> {
52 if i.get() >= total.get() {
53 return None;
54 }
55 match self {
56 Self::NonZeroCounter => {
57 Some(Scalar::from_u64(i.get() as u64 + 1))
59 }
60 #[cfg(not(any(
61 commonware_stability_BETA,
62 commonware_stability_GAMMA,
63 commonware_stability_DELTA,
64 commonware_stability_EPSILON,
65 commonware_stability_RESERVED
66 )))]
67 Self::RootsOfUnity => {
68 let size = (total.get() as u64).next_power_of_two();
71 let lg_size = size.ilog2() as u8;
72 let w = Scalar::root_of_unity(lg_size).expect("domain too large for NTT");
73 Some(w.exp(&[i.get() as u64]))
74 }
75 }
76 }
77
78 #[cfg(feature = "std")]
80 pub(crate) fn all_scalars(self, total: NonZeroU32) -> Vec<Scalar> {
81 match self {
82 Self::NonZeroCounter => (0..total.get())
83 .map(|i| Scalar::from_u64(i as u64 + 1))
84 .collect(),
85 #[cfg(not(any(
86 commonware_stability_BETA,
87 commonware_stability_GAMMA,
88 commonware_stability_DELTA,
89 commonware_stability_EPSILON,
90 commonware_stability_RESERVED
91 )))]
92 Self::RootsOfUnity => {
93 let size = (total.get() as u64).next_power_of_two();
94 let lg_size = size.ilog2() as u8;
95 let w = Scalar::root_of_unity(lg_size).expect("domain too large for NTT");
96 (0..total.get())
97 .scan(Scalar::one(), |state, _| {
98 let val = state.clone();
99 *state *= &w;
100 Some(val)
101 })
102 .collect()
103 }
104 }
105 }
106
107 fn interpolator<I: Clone + Ord>(
118 self,
119 total: NonZeroU32,
120 indices: &Set<I>,
121 to_index: impl Fn(&I) -> Option<Participant>,
122 ) -> Option<Interpolator<I, Scalar>> {
123 match self {
124 Self::NonZeroCounter => {
125 let mut count = 0;
126 let iter = indices
127 .iter()
128 .filter_map(|i| {
129 let scalar = self.scalar(total, to_index(i)?)?;
130 Some((i.clone(), scalar))
131 })
132 .inspect(|_| {
133 count += 1;
134 });
135 let out = Interpolator::new(iter);
136 if count != indices.len() {
138 return None;
139 }
140 Some(out)
141 }
142 #[cfg(not(any(
143 commonware_stability_BETA,
144 commonware_stability_GAMMA,
145 commonware_stability_DELTA,
146 commonware_stability_EPSILON,
147 commonware_stability_RESERVED
148 )))]
149 Self::RootsOfUnity => {
150 let size = (total.get() as u64).next_power_of_two();
153 let ntt_total = NonZeroU32::new(u32::try_from(size).ok()?)?;
154
155 let mut count = 0;
156 let points: Vec<(I, u32)> = indices
157 .iter()
158 .filter_map(|i| {
159 let participant = to_index(i)?;
160 if participant.get() >= total.get() {
161 return None;
162 }
163 count += 1;
164 Some((i.clone(), participant.get()))
165 })
166 .collect();
167
168 if count != indices.len() {
170 return None;
171 }
172
173 let points = BiMap::try_from_iter(points).ok()?;
174 Some(Interpolator::roots_of_unity(ntt_total, points))
175 }
176 }
177 }
178
179 #[cfg(feature = "std")]
187 pub(crate) fn subset_interpolator<I: Clone + Ord>(
188 self,
189 set: &Set<I>,
190 subset: &Set<I>,
191 ) -> Option<Interpolator<I, Scalar>> {
192 let Ok(total) = NonZeroU32::try_from(set.len() as u32) else {
193 return Some(Interpolator::new(iter::empty()));
194 };
195 self.interpolator(total, subset, |i| {
196 set.position(i).map(Participant::from_usize)
197 })
198 }
199}
200
201impl FixedSize for Mode {
202 const SIZE: usize = 1;
203}
204
205impl Write for Mode {
206 fn write(&self, buf: &mut impl bytes::BufMut) {
207 buf.put_u8(*self as u8);
208 }
209}
210
211#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
219pub struct ModeVersion(u8);
220
221impl ModeVersion {
222 pub const fn v0() -> Self {
226 Self(0)
227 }
228
229 #[stability(ALPHA)]
233 pub const fn v1() -> Self {
234 Self(1)
235 }
236
237 const fn supports(&self, mode: &Mode) -> bool {
238 match mode {
239 Mode::NonZeroCounter => true,
240 #[cfg(not(any(
241 commonware_stability_BETA,
242 commonware_stability_GAMMA,
243 commonware_stability_DELTA,
244 commonware_stability_EPSILON,
245 commonware_stability_RESERVED
246 )))]
247 Mode::RootsOfUnity => self.0 >= 1,
248 }
249 }
250}
251
252impl Read for Mode {
253 type Cfg = ModeVersion;
254
255 fn read_cfg(
256 buf: &mut impl bytes::Buf,
257 version: &Self::Cfg,
258 ) -> Result<Self, commonware_codec::Error> {
259 let tag: u8 = ReadExt::read(buf)?;
260 let mode = match tag {
261 0 => Self::NonZeroCounter,
262 #[cfg(not(any(
263 commonware_stability_BETA,
264 commonware_stability_GAMMA,
265 commonware_stability_DELTA,
266 commonware_stability_EPSILON,
267 commonware_stability_RESERVED
268 )))]
269 1 => Self::RootsOfUnity,
270 o => return Err(commonware_codec::Error::InvalidEnum(o)),
271 };
272 if !version.supports(&mode) {
273 return Err(commonware_codec::Error::Invalid(
274 "Mode",
275 "unsupported mode for version",
276 ));
277 }
278 Ok(mode)
279 }
280}
281
/// A committed sharing polynomial together with its evaluation-point [`Mode`]
/// and participant count.
///
/// Cloning is cheap: the polynomial and (under `std`) the evaluation cache are
/// held behind `Arc`s, so clones share the same lazily filled cache.
#[derive(Clone, Debug)]
pub struct Sharing<V: Variant> {
    /// How participant indices map to evaluation points.
    mode: Mode,
    /// Number of participants; valid participant indices are `0..total`.
    total: NonZeroU32,
    /// The polynomial; its constant term is what `public()` returns.
    poly: Arc<Poly<V::Public>>,
    /// One slot per participant caching `poly` evaluated at that participant's
    /// point; filled on demand by `partial_public` and
    /// `precompute_partial_publics`.
    #[cfg(feature = "std")]
    evals: Arc<Vec<OnceLock<V::Public>>>,
}
293
294impl<V: Variant> PartialEq for Sharing<V> {
295 fn eq(&self, other: &Self) -> bool {
296 self.mode == other.mode && self.total == other.total && self.poly == other.poly
297 }
298}
299
300impl<V: Variant> Eq for Sharing<V> {}
301
302impl<V: Variant> Sharing<V> {
303 pub(crate) fn new(mode: Mode, total: NonZeroU32, poly: Poly<V::Public>) -> Self {
304 Self {
305 mode,
306 total,
307 poly: Arc::new(poly),
308 #[cfg(feature = "std")]
309 evals: Arc::new(vec![OnceLock::new(); total.get() as usize]),
310 }
311 }
312
313 #[cfg(feature = "std")]
315 pub(crate) const fn mode(&self) -> Mode {
316 self.mode
317 }
318
319 pub(crate) fn scalar(&self, i: Participant) -> Option<Scalar> {
320 self.mode.scalar(self.total, i)
321 }
322
323 #[cfg(feature = "std")]
324 fn all_scalars(&self) -> Vec<Scalar> {
325 self.mode.all_scalars(self.total)
326 }
327
328 pub fn required<M: Faults>(&self) -> u32 {
331 M::quorum(self.total.get())
332 }
333
334 pub const fn total(&self) -> NonZeroU32 {
336 self.total
337 }
338
339 pub(crate) fn interpolator(
343 &self,
344 indices: &Set<Participant>,
345 ) -> Result<Interpolator<Participant, Scalar>, Error> {
346 self.mode
347 .interpolator(self.total, indices, |&x| Some(x))
348 .ok_or(Error::InvalidIndex)
349 }
350
351 #[cfg(feature = "std")]
359 pub fn precompute_partial_publics(&self) {
360 self.evals
362 .iter()
363 .zip(self.all_scalars())
364 .for_each(|(e, s)| {
365 e.get_or_init(|| self.poly.eval_msm(&s, &Sequential));
366 })
367 }
368
369 pub fn partial_public(&self, i: Participant) -> Result<V::Public, Error> {
373 cfg_if! {
374 if #[cfg(feature = "std")] {
375 self.evals
376 .get(usize::from(i))
377 .map(|e| {
378 *e.get_or_init(|| {
379 self.poly
380 .eval_msm(&self.scalar(i).expect("i < total"), &Sequential)
381 })
382 })
383 .ok_or(Error::InvalidIndex)
384 } else {
385 Ok(self
386 .poly
387 .eval_msm(&self.scalar(i).ok_or(Error::InvalidIndex)?, &Sequential))
388 }
389 }
390 }
391
392 pub fn public(&self) -> &V::Public {
396 self.poly.constant()
397 }
398}
399
400impl<V: Variant> EncodeSize for Sharing<V> {
401 fn encode_size(&self) -> usize {
402 self.mode.encode_size() + self.total.get().encode_size() + self.poly.encode_size()
403 }
404}
405
406impl<V: Variant> Write for Sharing<V> {
407 fn write(&self, buf: &mut impl bytes::BufMut) {
408 self.mode.write(buf);
409 self.total.get().write(buf);
410 self.poly.write(buf);
411 }
412}
413
414impl<V: Variant> Read for Sharing<V> {
415 type Cfg = (NonZeroU32, ModeVersion);
416
417 fn read_cfg(
418 buf: &mut impl bytes::Buf,
419 (max_participants, max_supported_mode): &Self::Cfg,
420 ) -> Result<Self, commonware_codec::Error> {
421 let mode = Read::read_cfg(buf, max_supported_mode)?;
422 let total = {
425 let out: u32 = ReadExt::read(buf)?;
426 if out == 0 || out > max_participants.get() {
427 return Err(commonware_codec::Error::Invalid(
428 "Sharing",
429 "total not in range",
430 ));
431 }
432 NZU32!(out)
434 };
435 let poly = Read::read_cfg(buf, &(RangeCfg::from(NZU32!(1)..=*max_participants), ()))?;
436 Ok(Self::new(mode, total, poly))
437 }
438}
439
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use commonware_invariants::minifuzz;
    use commonware_utils::ordered::Map;
    use rand::{rngs::StdRng, SeedableRng};

    /// Maps a fuzzer-drawn tag in `0..=1` onto a [`Mode`].
    fn mode_from_tag(tag: u8) -> Mode {
        match tag {
            0 => Mode::NonZeroCounter,
            1 => Mode::RootsOfUnity,
            _ => unreachable!("range is 0..=1"),
        }
    }

    #[test]
    fn test_roots_of_unity_interpolator_large_total_returns_none() {
        let total = NonZeroU32::new(u32::MAX).expect("u32::MAX is non-zero");
        let indices = Set::from_iter_dedup([Participant::new(0)]);
        let outcome = Mode::RootsOfUnity.interpolator(total, &indices, |p| Some(*p));
        assert!(
            outcome.is_none(),
            "domain > u32::MAX should be rejected instead of panicking"
        );
    }

    #[test]
    fn test_mode_read_rejects_mode_above_max_supported_mode() {
        let encoded = [Mode::RootsOfUnity as u8];
        let mut cursor: &[u8] = &encoded;
        let result = Mode::read_cfg(&mut cursor, &ModeVersion::v0());
        result.expect_err("roots mode must be rejected when max mode is counter");
    }

    #[test]
    fn test_all_scalars_matches_scalar() {
        minifuzz::test(|u| {
            let mode = mode_from_tag(u.int_in_range(0u8..=1)?);
            let total =
                NonZeroU32::new(u.int_in_range(1u32..=512u32)?).expect("range is non-zero");
            let participant = Participant::new(u.int_in_range(0u32..=total.get() - 1)?);

            let every = mode.all_scalars(total);
            let single = mode.scalar(total, participant).expect("index is in range");
            assert_eq!(every[usize::from(participant)], single);
            Ok(())
        });
    }

    #[test]
    fn test_subset_interpolation_recovers_constant() {
        minifuzz::test(|u| {
            let mode = mode_from_tag(u.int_in_range(0u8..=1)?);
            let total =
                NonZeroU32::new(u.int_in_range(1u32..=64u32)?).expect("range is non-zero");

            // Keep each participant on a coin flip; fall back to one random participant.
            let mut members = Vec::new();
            for index in 0..total.get() {
                if u.arbitrary::<bool>()? {
                    members.push(Participant::new(index));
                }
            }
            if members.is_empty() {
                members.push(Participant::new(u.int_in_range(0u32..=total.get() - 1)?));
            }
            let subset = Set::from_iter_dedup(members);

            // Any degree below the subset size is recoverable from the subset alone.
            let max_degree = u32::try_from(subset.len() - 1).expect("subset len fits in u32");
            let degree = u.int_in_range(0u32..=max_degree)?;
            let seed: u64 = u.arbitrary()?;
            let poly: Poly<Scalar> = Poly::new(&mut StdRng::seed_from_u64(seed), degree);

            let all_shares = Map::from_iter_dedup((0..total.get()).map(|index| {
                let participant = Participant::new(index);
                let point = mode.scalar(total, participant).expect("in range");
                (participant, poly.eval(&point))
            }));
            let subset_evals = Map::from_iter_dedup(subset.iter().map(|participant| {
                let share = all_shares
                    .get_value(participant)
                    .expect("participant exists")
                    .clone();
                (*participant, share)
            }));

            let interpolator = mode
                .interpolator(total, &subset, |participant| Some(*participant))
                .expect("subset indices are valid");
            let recovered = interpolator
                .interpolate(&subset_evals, &Sequential)
                .expect("subset should match interpolator domain");
            assert_eq!(recovered, poly.constant().clone());
            Ok(())
        });
    }
}
543
#[cfg(feature = "arbitrary")]
mod fuzz {
    use super::*;
    use arbitrary::Arbitrary;
    use commonware_utils::{N3f1, NZU32};
    use rand::{rngs::StdRng, SeedableRng};

    impl<'a> Arbitrary<'a> for Mode {
        /// Draws a mode from a tag in `0..=1`.
        ///
        /// The tag is always drawn from the same range, so a given input is
        /// consumed identically whichever modes are compiled in. When
        /// `RootsOfUnity` is compiled out by a stability cfg, every tag falls
        /// back to `NonZeroCounter` rather than failing to compile.
        fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
            let tag = u.int_in_range(0u8..=1)?;
            Ok(match tag {
                #[cfg(not(any(
                    commonware_stability_BETA,
                    commonware_stability_GAMMA,
                    commonware_stability_DELTA,
                    commonware_stability_EPSILON,
                    commonware_stability_RESERVED
                )))]
                1 => Self::RootsOfUnity,
                _ => Self::NonZeroCounter,
            })
        }
    }

    impl<'a, V: Variant> Arbitrary<'a> for Sharing<V> {
        /// Builds a sharing over 1..=100 participants whose committed polynomial
        /// has `N3f1::quorum(total)` coefficients.
        fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
            let total: u32 = u.int_in_range(1..=100)?;
            let mode: Mode = u.arbitrary()?;
            let seed: u64 = u.arbitrary()?;
            let poly = Poly::new(&mut StdRng::seed_from_u64(seed), N3f1::quorum(total) - 1);
            Ok(Self::new(
                mode,
                NZU32!(total),
                Poly::<V::Public>::commit(poly),
            ))
        }
    }
}