1use core::fmt::Display;
2use core::{
3 fmt::{self, Debug, Formatter},
4 marker::PhantomData,
5 num::NonZeroUsize,
6};
7use rand_core::{CryptoRng, RngCore};
8use sha3::digest::ExtendableOutput;
9use sha3::{
10 digest::{Update, XofReader},
11 Shake256,
12};
13
14use crate::{Error, ShareIdentifier, VsssResult};
15
16#[derive(Debug, Clone)]
18pub enum ParticipantIdGeneratorType<'a, I: ShareIdentifier> {
19 Sequential {
22 start: I,
24 increment: I,
26 count: usize,
28 },
29 Random {
32 seed: [u8; 32],
34 count: usize,
36 },
37 List {
39 list: &'a [I],
41 },
42}
43
44impl<'a, I: ShareIdentifier + Copy> Copy for ParticipantIdGeneratorType<'a, I> {}
45
46impl<I: ShareIdentifier + Display> Display for ParticipantIdGeneratorType<'_, I> {
47 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
48 match self {
49 Self::Sequential {
50 start,
51 increment,
52 count,
53 } => write!(
54 f,
55 "Sequential {{ start: {}, increment: {}, count: {} }}",
56 start, increment, count
57 ),
58 Self::Random { seed, count } => {
59 write!(f, "Random {{ seed: ")?;
60 for &b in seed {
61 write!(f, "{:02x}", b)?;
62 }
63 write!(f, ", count: {} }}", count)
64 }
65 Self::List { list } => {
66 write!(f, "List {{ list: ")?;
67 for id in list.iter() {
68 write!(f, "{}, ", id)?;
69 }
70 write!(f, "}}")
71 }
72 }
73 }
74}
75
76impl<I: ShareIdentifier> Default for ParticipantIdGeneratorType<'_, I> {
77 fn default() -> Self {
78 Self::Sequential {
79 start: I::one(),
80 increment: I::one(),
81 count: u16::MAX as usize,
82 }
83 }
84}
85
#[cfg(any(feature = "alloc", feature = "std"))]
impl<'a, I: ShareIdentifier> From<&'a crate::Vec<I>> for ParticipantIdGeneratorType<'a, I> {
    /// Borrows the vector's contents as a `List` generator.
    fn from(ids: &'a crate::Vec<I>) -> Self {
        Self::from(ids.as_slice())
    }
}
92
93impl<'a, I: ShareIdentifier> From<&'a [I]> for ParticipantIdGeneratorType<'a, I> {
94 fn from(list: &'a [I]) -> Self {
95 Self::List { list }
96 }
97}
98
99impl<'a, I: ShareIdentifier> ParticipantIdGeneratorType<'a, I> {
100 pub fn sequential(start: Option<I>, increment: Option<I>, count: NonZeroUsize) -> Self {
102 Self::Sequential {
103 start: start.unwrap_or_else(I::one),
104 increment: increment.unwrap_or_else(I::one),
105 count: count.get(),
106 }
107 }
108
109 pub fn random(seed: [u8; 32], count: NonZeroUsize) -> Self {
111 Self::Random {
112 seed,
113 count: count.get(),
114 }
115 }
116
117 pub fn list(list: &'a [I]) -> Self {
119 Self::List { list }
120 }
121
122 pub(crate) fn try_into_generator(&self) -> VsssResult<ParticipantIdGeneratorState<'a, I>> {
123 match self {
124 Self::Sequential {
125 start,
126 increment,
127 count,
128 } => {
129 if *count == 0 {
130 return Err(Error::InvalidGenerator(
131 "The count must be greater than zero",
132 ));
133 }
134 Ok(ParticipantIdGeneratorState::Sequential(
135 SequentialParticipantNumberGenerator {
136 start: start.clone(),
137 increment: increment.clone(),
138 index: 0,
139 count: *count,
140 },
141 ))
142 }
143 Self::Random { seed, count } => {
144 if *count == 0 {
145 return Err(Error::InvalidGenerator(
146 "The count must be greater than zero",
147 ));
148 }
149 Ok(ParticipantIdGeneratorState::Random(
150 RandomParticipantNumberGenerator {
151 dst: *seed,
152 index: 0,
153 count: *count,
154 _markers: PhantomData,
155 },
156 ))
157 }
158 Self::List { list } => Ok(ParticipantIdGeneratorState::List(
159 ListParticipantNumberGenerator { list, index: 0 },
160 )),
161 }
162 }
163}
164
165#[derive(Debug, Clone)]
167pub struct ParticipantIdGeneratorCollection<'a, 'b, I: ShareIdentifier> {
168 pub generators: &'a [ParticipantIdGeneratorType<'b, I>],
170}
171
172impl<'a, 'b, I: ShareIdentifier + Copy> Copy for ParticipantIdGeneratorCollection<'a, 'b, I> {}
173
174impl<'a, 'b, I: ShareIdentifier> From<&'a [ParticipantIdGeneratorType<'b, I>]>
175 for ParticipantIdGeneratorCollection<'a, 'b, I>
176{
177 fn from(generators: &'a [ParticipantIdGeneratorType<'b, I>]) -> Self {
178 Self { generators }
179 }
180}
181
182impl<'a, 'b, I: ShareIdentifier, const L: usize> From<&'a [ParticipantIdGeneratorType<'b, I>; L]>
183 for ParticipantIdGeneratorCollection<'a, 'b, I>
184{
185 fn from(generators: &'a [ParticipantIdGeneratorType<'b, I>; L]) -> Self {
186 Self { generators }
187 }
188}
189
#[cfg(any(feature = "alloc", feature = "std"))]
impl<'a, 'b, I: ShareIdentifier> From<&'a crate::Vec<ParticipantIdGeneratorType<'b, I>>>
    for ParticipantIdGeneratorCollection<'a, 'b, I>
{
    /// Wraps a borrowed vector of generators by viewing its contents as a slice.
    fn from(list: &'a crate::Vec<ParticipantIdGeneratorType<'b, I>>) -> Self {
        Self::from(list.as_slice())
    }
}
200
201impl<'a, 'b, I: ShareIdentifier> ParticipantIdGeneratorCollection<'a, 'b, I> {
202 pub fn iter(&self) -> impl Iterator<Item = I> + '_ {
207 let mut participant_id_iter = self.generators.iter().map(|g| g.try_into_generator());
208 let mut current: Option<ParticipantIdGeneratorState<'a, I>> = None;
209 core::iter::from_fn(move || {
210 loop {
211 if let Some(ref mut generator) = current {
212 match generator.next() {
213 Some(id) => {
214 if id.is_zero().into() {
215 current = None; continue;
217 }
218 return Some(id);
219 }
220 None => {
221 current = None; }
223 }
224 }
225
226 match participant_id_iter.next() {
228 Some(Ok(new_generator)) => {
229 current = Some(new_generator);
230 }
232 Some(Err(_)) => return None, None => return None, }
235 }
236 })
237 }
238}
239
240pub(crate) enum ParticipantIdGeneratorState<'a, I: ShareIdentifier> {
241 Sequential(SequentialParticipantNumberGenerator<I>),
242 Random(RandomParticipantNumberGenerator<I>),
243 List(ListParticipantNumberGenerator<'a, I>),
244}
245
246impl<'a, I: ShareIdentifier> Iterator for ParticipantIdGeneratorState<'a, I> {
247 type Item = I;
248
249 fn next(&mut self) -> Option<Self::Item> {
250 match self {
251 Self::Sequential(gen) => gen.next(),
252 Self::Random(gen) => gen.next(),
253 Self::List(gen) => gen.next(),
254 }
255 }
256}
257
258#[derive(Debug)]
259pub(crate) struct SequentialParticipantNumberGenerator<I: ShareIdentifier> {
261 start: I,
262 increment: I,
263 index: usize,
264 count: usize,
265}
266
267impl<I: ShareIdentifier> Iterator for SequentialParticipantNumberGenerator<I> {
268 type Item = I;
269
270 fn next(&mut self) -> Option<Self::Item> {
271 if self.index >= self.count {
272 return None;
273 }
274 let value = self.start.clone();
275 self.start.inc(&self.increment);
276 self.index += 1;
277 Some(value)
278 }
279}
280
281#[derive(Debug)]
283pub(crate) struct RandomParticipantNumberGenerator<I: ShareIdentifier> {
284 dst: [u8; 32],
286 index: usize,
287 count: usize,
288 _markers: PhantomData<I>,
289}
290
291impl<I: ShareIdentifier> Iterator for RandomParticipantNumberGenerator<I> {
292 type Item = I;
293
294 fn next(&mut self) -> Option<Self::Item> {
295 if self.index >= self.count {
296 return None;
297 }
298 self.index += 1;
299 Some(I::random(self.get_rng(self.index)))
300 }
301}
302
303impl<I: ShareIdentifier> RandomParticipantNumberGenerator<I> {
304 fn get_rng(&self, index: usize) -> XofRng {
305 let mut hasher = Shake256::default();
306 hasher.update(&self.dst);
307 hasher.update(&index.to_be_bytes());
308 hasher.update(&self.count.to_be_bytes());
309 XofRng(hasher.finalize_xof())
310 }
311}
312
313#[derive(Debug)]
315pub(crate) struct ListParticipantNumberGenerator<'a, I: ShareIdentifier> {
316 list: &'a [I],
317 index: usize,
318}
319
320impl<'a, I: ShareIdentifier> Iterator for ListParticipantNumberGenerator<'a, I> {
321 type Item = I;
322
323 fn next(&mut self) -> Option<Self::Item> {
324 if self.index >= self.list.len() {
325 return None;
326 }
327 let index = self.index;
328 self.index += 1;
329 Some(self.list[index].clone())
330 }
331}
332
333#[derive(Clone)]
334#[repr(transparent)]
335struct XofRng(<Shake256 as ExtendableOutput>::Reader);
336
337impl RngCore for XofRng {
338 fn next_u32(&mut self) -> u32 {
339 let mut buf = [0u8; 4];
340 self.0.read(&mut buf);
341 u32::from_be_bytes(buf)
342 }
343
344 fn next_u64(&mut self) -> u64 {
345 let mut buf = [0u8; 8];
346 self.0.read(&mut buf);
347 u64::from_be_bytes(buf)
348 }
349
350 fn fill_bytes(&mut self, dest: &mut [u8]) {
351 self.0.read(dest);
352 }
353
354 fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> {
355 self.0.read(dest);
356 Ok(())
357 }
358}
359
360impl CryptoRng for XofRng {}
361
362impl Debug for XofRng {
363 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
364 write!(f, "XofRng")
365 }
366}
367
#[cfg(all(test, any(feature = "alloc", feature = "std")))]
mod tests {
    use super::*;
    use crate::*;
    use elliptic_curve::PrimeField;
    use k256::{FieldBytes, Scalar};
    use rand_core::SeedableRng;

    /// Ids expected from a random generator seeded by `test_seed()` with `count = 5`.
    const EXPECTED_RANDOM_IDS: [&str; 5] = [
        "134de46908fd0867a9c14ed96e90cd34be47e2b052ca266499687adae4cfe445",
        "5b182d31afa277bcfb5d6316c31e231004d29f2c99e4dec0c384d7a46439c8ca",
        "cb15c36dfe7b15c253e3f9fde1fd9ccfbd75839ff6dccca49700cb831dc5802e",
        "bb3a92d716f6a8d94d82295fd120b23d42ec8543a405ecd82e519ab0fe4ef965",
        "a0fff4c9e992f0d1acc8bc90fe6ae31dee280a0175a028a6333dde56de2121ec",
    ];

    /// Builds a k256 identifier from a small integer.
    fn id(n: u64) -> IdentifierPrimeField<Scalar> {
        IdentifierPrimeField::from(Scalar::from(n))
    }

    /// Builds a k256 identifier from a hex-encoded `FieldBytes` representation.
    fn id_from_hex(s: &str) -> IdentifierPrimeField<Scalar> {
        let bytes = hex::decode(s).unwrap();
        IdentifierPrimeField::from(
            Scalar::from_repr(FieldBytes::clone_from_slice(&bytes)).unwrap(),
        )
    }

    /// The identifiers 10, 20, 30, 40, 50 used by the list-based tests.
    fn decade_ids() -> Vec<IdentifierPrimeField<Scalar>> {
        (1..=5u64).map(|n| id(n * 10)).collect()
    }

    /// The 32-byte seed shared by the random-generator tests.
    fn test_seed() -> [u8; 32] {
        let mut rng = rand_chacha::ChaCha8Rng::from_seed([1u8; 32]);
        let mut dst = [0u8; 32];
        rng.fill_bytes(&mut dst);
        dst
    }

    #[test]
    fn test_sequential_participant_number_generator() {
        let gen = SequentialParticipantNumberGenerator::<IdentifierPrimeField<Scalar>> {
            start: IdentifierPrimeField::<Scalar>::ONE,
            increment: IdentifierPrimeField::<Scalar>::ONE,
            index: 0,
            count: 5,
        };
        let list: Vec<_> = gen.collect();
        let expected: Vec<_> = (1..=5u64).map(id).collect();
        assert_eq!(list, expected);
    }

    #[test]
    fn test_random_participant_number_generator() {
        let gen = RandomParticipantNumberGenerator::<IdentifierPrimeField<Scalar>> {
            dst: test_seed(),
            index: 0,
            count: 5,
            _markers: PhantomData,
        };
        let list: Vec<_> = gen.collect();
        let expected: Vec<_> = EXPECTED_RANDOM_IDS.iter().copied().map(id_from_hex).collect();
        assert_eq!(list, expected);
    }

    #[test]
    fn test_list_participant_number_generator() {
        let source = decade_ids();
        let gen = ListParticipantNumberGenerator {
            list: &source[..],
            index: 0,
        };
        let list: Vec<_> = gen.collect();
        assert_eq!(list, source);
    }

    #[test]
    fn test_list_and_sequential_number_generator() {
        let list = decade_ids();
        let set = [
            ParticipantIdGeneratorType::list(&list),
            ParticipantIdGeneratorType::sequential(
                Some(id(51)),
                Some(IdentifierPrimeField::<Scalar>::ONE),
                NonZeroUsize::new(5).unwrap(),
            ),
        ];
        let collection = ParticipantIdGeneratorCollection::from(&set[..]);

        let expected: Vec<_> = [10u64, 20, 30, 40, 50]
            .iter()
            .copied()
            .chain(51..=55)
            .map(id)
            .collect();
        assert_eq!(collection.iter().collect::<Vec<_>>(), expected);
    }

    #[test]
    fn test_list_and_random_number_generator() {
        let list = decade_ids();
        let set = [
            ParticipantIdGeneratorType::list(&list),
            ParticipantIdGeneratorType::random(test_seed(), NonZeroUsize::new(5).unwrap()),
        ];
        let collection = ParticipantIdGeneratorCollection::from(&set);

        let expected: Vec<_> = list
            .iter()
            .cloned()
            .chain(EXPECTED_RANDOM_IDS.iter().copied().map(id_from_hex))
            .collect();
        assert_eq!(collection.iter().collect::<Vec<_>>(), expected);
    }

    #[test]
    fn test_empty_list_and_sequential_number_generator() {
        let list: [IdentifierPrimeField<Scalar>; 0] = [];
        let generators = [
            ParticipantIdGeneratorType::list(&list),
            ParticipantIdGeneratorType::sequential(None, None, NonZeroUsize::new(5).unwrap()),
        ];
        let collection = ParticipantIdGeneratorCollection::from(&generators);
        let list: Vec<_> = collection.iter().collect();
        let expected: Vec<_> = (1..=5u64).map(id).collect();
        assert_eq!(list, expected);
    }
}
573}