1use super::xof::XofTurboShake128;
31#[cfg(feature = "experimental")]
32use super::AggregatorWithNoise;
33use crate::codec::{CodecError, Decode, Encode, ParameterizedDecode};
34#[cfg(feature = "experimental")]
35use crate::dp::DifferentialPrivacyStrategy;
36use crate::field::{decode_fieldvec, FftFriendlyFieldElement, FieldElement};
37use crate::field::{Field128, Field64};
38#[cfg(feature = "multithreaded")]
39use crate::flp::gadgets::ParallelSumMultithreaded;
40#[cfg(feature = "experimental")]
41use crate::flp::gadgets::PolyEval;
42use crate::flp::gadgets::{Mul, ParallelSum};
43#[cfg(feature = "experimental")]
44use crate::flp::types::fixedpoint_l2::{
45 compatible_float::CompatibleFloat, FixedPointBoundedL2VecSum,
46};
47use crate::flp::types::{Average, Count, Histogram, Sum, SumVec};
48use crate::flp::Type;
49#[cfg(feature = "experimental")]
50use crate::flp::TypeWithNoise;
51use crate::prng::Prng;
52use crate::vdaf::xof::{IntoFieldVec, Seed, Xof};
53use crate::vdaf::{
54 Aggregatable, AggregateShare, Aggregator, Client, Collector, OutputShare, PrepareTransition,
55 Share, ShareDecodingParameter, Vdaf, VdafError,
56};
57#[cfg(feature = "experimental")]
58use fixed::traits::Fixed;
59use std::borrow::Cow;
60use std::convert::TryFrom;
61use std::fmt::Debug;
62use std::io::Cursor;
63use std::iter::{self, IntoIterator};
64use std::marker::PhantomData;
65use subtle::{Choice, ConstantTimeEq};
66
// Usage values passed to `domain_separation_tag` for each kind of XOF derivation below.
// Each derivation uses its own distinct value so their outputs are domain-separated.
/// Expansion of a helper's measurement-share seed.
const DST_MEASUREMENT_SHARE: u16 = 0x0001;
/// Expansion of a helper's proofs-share seed.
const DST_PROOF_SHARE: u16 = 0x0002;
/// Expansion of the joint randomness seed into field elements.
const DST_JOINT_RANDOMNESS: u16 = 0x0003;
/// Expansion of the prover randomness seed.
const DST_PROVE_RANDOMNESS: u16 = 0x0004;
/// Derivation of the query randomness from the verification key.
const DST_QUERY_RANDOMNESS: u16 = 0x0005;
/// Hashing of the joint randomness parts into the joint randomness seed.
const DST_JOINT_RAND_SEED: u16 = 0x0006;
/// Computation of a single aggregator's joint randomness part.
const DST_JOINT_RAND_PART: u16 = 0x0007;
74
75pub type Prio3Count = Prio3<Count<Field64>, XofTurboShake128, 16>;
77
78impl Prio3Count {
79 pub fn new_count(num_aggregators: u8) -> Result<Self, VdafError> {
81 Prio3::new(num_aggregators, 1, 0x00000000, Count::new())
82 }
83}
84
85pub type Prio3SumVec =
88 Prio3<SumVec<Field128, ParallelSum<Field128, Mul<Field128>>>, XofTurboShake128, 16>;
89
90impl Prio3SumVec {
91 pub fn new_sum_vec(
95 num_aggregators: u8,
96 bits: usize,
97 len: usize,
98 chunk_length: usize,
99 ) -> Result<Self, VdafError> {
100 Prio3::new(
101 num_aggregators,
102 1,
103 0x00000002,
104 SumVec::new(bits, len, chunk_length)?,
105 )
106 }
107}
108
/// Variant of `Prio3SumVec` whose gadget is `ParallelSumMultithreaded`, so gadget
/// evaluations can be spread across threads. Requires the `multithreaded` feature.
#[cfg(feature = "multithreaded")]
#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))]
pub type Prio3SumVecMultithreaded = Prio3<
    SumVec<Field128, ParallelSumMultithreaded<Field128, Mul<Field128>>>,
    XofTurboShake128,
    16,
>;

#[cfg(feature = "multithreaded")]
impl Prio3SumVecMultithreaded {
    /// Builds a [`Prio3SumVecMultithreaded`] with a single proof and algorithm ID
    /// `0x00000002` (the same ID as the single-threaded variant).
    ///
    /// # Errors
    ///
    /// Returns an error if [`SumVec::new`] rejects the parameters, or if [`Prio3::new`]
    /// rejects `num_aggregators`.
    pub fn new_sum_vec_multithreaded(
        num_aggregators: u8,
        bits: usize,
        len: usize,
        chunk_length: usize,
    ) -> Result<Self, VdafError> {
        let typ = SumVec::new(bits, len, chunk_length)?;
        Prio3::new(num_aggregators, 1, 0x0000_0002, typ)
    }
}
138
139pub type Prio3Sum = Prio3<Sum<Field128>, XofTurboShake128, 16>;
142
143impl Prio3Sum {
144 pub fn new_sum(num_aggregators: u8, bits: usize) -> Result<Self, VdafError> {
147 if bits > 64 {
148 return Err(VdafError::Uncategorized(format!(
149 "bit length ({bits}) exceeds limit for aggregate type (64)"
150 )));
151 }
152
153 Prio3::new(num_aggregators, 1, 0x00000001, Sum::new(bits)?)
154 }
155}
156
/// Prio3 instantiated with the [`FixedPointBoundedL2VecSum`] FLP type over [`Field128`],
/// for fixed-point vectors of type `Fx`. Requires the `experimental` feature.
#[cfg(feature = "experimental")]
#[cfg_attr(docsrs, doc(cfg(feature = "experimental")))]
pub type Prio3FixedPointBoundedL2VecSum<Fx> = Prio3<
    FixedPointBoundedL2VecSum<
        Fx,
        ParallelSum<Field128, PolyEval<Field128>>,
        ParallelSum<Field128, Mul<Field128>>,
    >,
    XofTurboShake128,
    16,
>;

#[cfg(feature = "experimental")]
impl<Fx: Fixed + CompatibleFloat> Prio3FixedPointBoundedL2VecSum<Fx> {
    /// Builds an instance over vectors of `entries` elements, using a single proof and
    /// algorithm ID `0xFFFF0000`.
    ///
    /// # Errors
    ///
    /// Returns an error if `num_aggregators` is rejected or if
    /// [`FixedPointBoundedL2VecSum::new`] rejects `entries`.
    pub fn new_fixedpoint_boundedl2_vec_sum(
        num_aggregators: u8,
        entries: usize,
    ) -> Result<Self, VdafError> {
        // Check the aggregator count first, so that its error takes precedence over any
        // error from constructing the type.
        check_num_aggregators(num_aggregators)?;
        let typ = FixedPointBoundedL2VecSum::new(entries)?;
        Prio3::new(num_aggregators, 1, 0xFFFF_0000, typ)
    }
}
198
/// Variant of `Prio3FixedPointBoundedL2VecSum` whose gadgets are
/// `ParallelSumMultithreaded`. Requires both the `experimental` and `multithreaded`
/// features.
#[cfg(all(feature = "experimental", feature = "multithreaded"))]
#[cfg_attr(
    docsrs,
    doc(cfg(all(feature = "experimental", feature = "multithreaded")))
)]
pub type Prio3FixedPointBoundedL2VecSumMultithreaded<Fx> = Prio3<
    FixedPointBoundedL2VecSum<
        Fx,
        ParallelSumMultithreaded<Field128, PolyEval<Field128>>,
        ParallelSumMultithreaded<Field128, Mul<Field128>>,
    >,
    XofTurboShake128,
    16,
>;

#[cfg(all(feature = "experimental", feature = "multithreaded"))]
impl<Fx: Fixed + CompatibleFloat> Prio3FixedPointBoundedL2VecSumMultithreaded<Fx> {
    /// Builds an instance over vectors of `entries` elements, using a single proof and
    /// algorithm ID `0xFFFF0000`.
    ///
    /// # Errors
    ///
    /// Returns an error if `num_aggregators` is rejected or if
    /// [`FixedPointBoundedL2VecSum::new`] rejects `entries`.
    pub fn new_fixedpoint_boundedl2_vec_sum_multithreaded(
        num_aggregators: u8,
        entries: usize,
    ) -> Result<Self, VdafError> {
        // Validate the aggregator count before constructing the type.
        check_num_aggregators(num_aggregators)?;
        let typ = FixedPointBoundedL2VecSum::new(entries)?;
        Prio3::new(num_aggregators, 1, 0xFFFF_0000, typ)
    }
}
234
235pub type Prio3Histogram =
238 Prio3<Histogram<Field128, ParallelSum<Field128, Mul<Field128>>>, XofTurboShake128, 16>;
239
240impl Prio3Histogram {
241 pub fn new_histogram(
244 num_aggregators: u8,
245 length: usize,
246 chunk_length: usize,
247 ) -> Result<Self, VdafError> {
248 Prio3::new(
249 num_aggregators,
250 1,
251 0x00000003,
252 Histogram::new(length, chunk_length)?,
253 )
254 }
255}
256
/// Variant of `Prio3Histogram` whose gadget is `ParallelSumMultithreaded`. Requires the
/// `multithreaded` feature.
#[cfg(feature = "multithreaded")]
#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))]
pub type Prio3HistogramMultithreaded = Prio3<
    Histogram<Field128, ParallelSumMultithreaded<Field128, Mul<Field128>>>,
    XofTurboShake128,
    16,
>;

#[cfg(feature = "multithreaded")]
impl Prio3HistogramMultithreaded {
    /// Builds a [`Prio3HistogramMultithreaded`] with `length` buckets, using a single proof
    /// and algorithm ID `0x00000003` (the same ID as the single-threaded variant).
    ///
    /// # Errors
    ///
    /// Returns an error if [`Histogram::new`] rejects the parameters, or if [`Prio3::new`]
    /// rejects `num_aggregators`.
    pub fn new_histogram_multithreaded(
        num_aggregators: u8,
        length: usize,
        chunk_length: usize,
    ) -> Result<Self, VdafError> {
        let typ = Histogram::new(length, chunk_length)?;
        Prio3::new(num_aggregators, 1, 0x0000_0003, typ)
    }
}
284
285pub type Prio3Average = Prio3<Average<Field128>, XofTurboShake128, 16>;
288
289impl Prio3Average {
290 pub fn new_average(num_aggregators: u8, bits: usize) -> Result<Self, VdafError> {
293 check_num_aggregators(num_aggregators)?;
294
295 if bits > 64 {
296 return Err(VdafError::Uncategorized(format!(
297 "bit length ({bits}) exceeds limit for aggregate type (64)"
298 )));
299 }
300
301 Ok(Prio3 {
302 num_aggregators,
303 num_proofs: 1,
304 algorithm_id: 0xFFFF0000,
305 typ: Average::new(bits)?,
306 phantom: PhantomData,
307 })
308 }
309}
310
/// A Prio3 VDAF instance.
///
/// `T` is the FLP type that defines how measurements are encoded, proved, and verified.
/// `P` is the XOF used to expand seeds. `SEED_SIZE` is the XOF seed length in bytes.
#[derive(Clone, Debug)]
pub struct Prio3<T, P, const SEED_SIZE: usize>
where
    T: Type,
    P: Xof<SEED_SIZE>,
{
    /// Number of aggregators that each measurement is split among; validated by
    /// `check_num_aggregators` at construction.
    num_aggregators: u8,
    /// Number of proofs generated and checked per report; never zero when built via
    /// `Prio3::new`.
    num_proofs: u8,
    /// Identifier returned by `Vdaf::algorithm_id`.
    algorithm_id: u32,
    /// The FLP type instance.
    typ: T,
    phantom: PhantomData<P>,
}
386
impl<T, P, const SEED_SIZE: usize> Prio3<T, P, SEED_SIZE>
where
    T: Type,
    P: Xof<SEED_SIZE>,
{
    /// Constructs a Prio3 instance from its raw parameters.
    ///
    /// # Errors
    ///
    /// Returns an error if `check_num_aggregators` rejects `num_aggregators`, or if
    /// `num_proofs` is zero.
    pub fn new(
        num_aggregators: u8,
        num_proofs: u8,
        algorithm_id: u32,
        typ: T,
    ) -> Result<Self, VdafError> {
        check_num_aggregators(num_aggregators)?;
        if num_proofs == 0 {
            return Err(VdafError::Uncategorized(
                "num_proofs must be at least 1".to_string(),
            ));
        }

        Ok(Self {
            num_aggregators,
            num_proofs,
            algorithm_id,
            typ,
            phantom: PhantomData,
        })
    }

    /// Returns the output-share length reported by the underlying FLP type.
    pub fn output_len(&self) -> usize {
        self.typ.output_len()
    }

    /// Returns the length of a single proof's verifier, as reported by the underlying FLP
    /// type.
    pub fn verifier_len(&self) -> usize {
        self.typ.verifier_len()
    }

    /// Returns the configured number of proofs as a `usize`, for length arithmetic.
    #[inline]
    fn num_proofs(&self) -> usize {
        self.num_proofs.into()
    }

    /// Expands `prove_rand_seed` into the prover randomness for all proofs. The result has
    /// `prove_rand_len() * num_proofs` elements, and the proof count is the XOF binder.
    fn derive_prove_rands(&self, prove_rand_seed: &Seed<SEED_SIZE>) -> Vec<T::Field> {
        P::seed_stream(
            prove_rand_seed,
            &self.domain_separation_tag(DST_PROVE_RANDOMNESS),
            &[self.num_proofs],
        )
        .into_field_vec(self.typ.prove_rand_len() * self.num_proofs())
    }

    /// Absorbs the joint randomness parts, in iteration order, into an XOF keyed with an
    /// all-zero seed, and returns the resulting joint randomness seed.
    fn derive_joint_rand_seed<'a>(
        &self,
        joint_rand_parts: impl Iterator<Item = &'a Seed<SEED_SIZE>>,
    ) -> Seed<SEED_SIZE> {
        let mut xof = P::init(
            &[0; SEED_SIZE],
            &self.domain_separation_tag(DST_JOINT_RAND_SEED),
        );
        for part in joint_rand_parts {
            xof.update(part.as_ref());
        }
        xof.into_seed()
    }

    /// Derives the joint randomness seed from `joint_rand_parts`, then expands it into
    /// `joint_rand_len() * num_proofs` field elements. Returns both the seed and the
    /// expanded elements.
    fn derive_joint_rands<'a>(
        &self,
        joint_rand_parts: impl Iterator<Item = &'a Seed<SEED_SIZE>>,
    ) -> (Seed<SEED_SIZE>, Vec<T::Field>) {
        let joint_rand_seed = self.derive_joint_rand_seed(joint_rand_parts);
        let joint_rands = P::seed_stream(
            &joint_rand_seed,
            &self.domain_separation_tag(DST_JOINT_RANDOMNESS),
            &[self.num_proofs],
        )
        .into_field_vec(self.typ.joint_rand_len() * self.num_proofs());

        (joint_rand_seed, joint_rands)
    }

    /// Returns a PRNG that yields helper `agg_id`'s proofs share, expanded from
    /// `proofs_share_seed` with binder `[num_proofs, agg_id]`. The stream is unbounded, so
    /// callers `take` the number of elements they need.
    fn derive_helper_proofs_share(
        &self,
        proofs_share_seed: &Seed<SEED_SIZE>,
        agg_id: u8,
    ) -> Prng<T::Field, P::SeedStream> {
        Prng::from_seed_stream(P::seed_stream(
            proofs_share_seed,
            &self.domain_separation_tag(DST_PROOF_SHARE),
            &[self.num_proofs, agg_id],
        ))
    }

    /// Derives the query randomness for all proofs from the verification key, the proof
    /// count, and the report nonce. The result has `query_rand_len() * num_proofs`
    /// elements.
    fn derive_query_rands(&self, verify_key: &[u8; SEED_SIZE], nonce: &[u8; 16]) -> Vec<T::Field> {
        let mut xof = P::init(
            verify_key,
            &self.domain_separation_tag(DST_QUERY_RANDOMNESS),
        );
        xof.update(&[self.num_proofs]);
        xof.update(nonce);
        xof.into_seed_stream()
            .into_field_vec(self.typ.query_rand_len() * self.num_proofs())
    }

    /// Number of random bytes that `shard_with_random` requires.
    fn random_size(&self) -> usize {
        if self.typ.joint_rand_len() == 0 {
            // Two seeds per helper (measurement share and proofs share), plus one seed for
            // the prover randomness.
            (usize::from(self.num_aggregators - 1) * 2 + 1) * SEED_SIZE
        } else {
            (
                // Two seeds per helper (measurement share and proofs share)
                usize::from(self.num_aggregators - 1) * 2
                // One seed for the prover randomness
                + 1
                // One joint randomness blind per aggregator, leader included
                + usize::from(self.num_aggregators)
            ) * SEED_SIZE
        }
    }

    /// Splits `measurement` into a public share and one input share per aggregator. The
    /// caller supplies the randomness in `random` instead of drawing it fresh.
    ///
    /// `random` must be exactly `random_size()` bytes. It is consumed in `SEED_SIZE` chunks
    /// in this order:
    /// 1. For each helper: its measurement-share seed, its proofs-share seed, and (if the
    ///    type uses joint randomness) its joint randomness blind.
    /// 2. If joint randomness is used, the leader's blind.
    /// 3. The prover randomness seed.
    ///
    /// The leader's input share (index 0 of the returned vector) holds explicit measurement
    /// and proofs vectors. Each helper's input share holds only seeds.
    ///
    /// # Errors
    ///
    /// Returns an error if `random` has the wrong length, if the measurement cannot be
    /// encoded, if encoding a field element fails, or if proof generation fails.
    #[allow(clippy::type_complexity)]
    pub(crate) fn shard_with_random<const N: usize>(
        &self,
        measurement: &T::Measurement,
        nonce: &[u8; N],
        random: &[u8],
    ) -> Result<
        (
            Prio3PublicShare<SEED_SIZE>,
            Vec<Prio3InputShare<T::Field, SEED_SIZE>>,
        ),
        VdafError,
    > {
        if random.len() != self.random_size() {
            return Err(VdafError::Uncategorized(
                "incorrect random input length".to_string(),
            ));
        }
        let mut random_seeds = random.chunks_exact(SEED_SIZE);
        let num_aggregators = self.num_aggregators;
        let encoded_measurement = self.typ.encode_measurement(measurement)?;

        // Generate the helpers' measurement shares and, if needed, their joint randomness
        // parts.
        let mut helper_shares = Vec::with_capacity(num_aggregators as usize - 1);
        let mut helper_joint_rand_parts = if self.typ.joint_rand_len() > 0 {
            Some(Vec::with_capacity(num_aggregators as usize - 1))
        } else {
            None
        };
        // The leader's share starts as the full encoding. Each helper's expanded share is
        // subtracted from it below, so all the shares sum to the encoded measurement.
        let mut leader_measurement_share = encoded_measurement.clone();
        for agg_id in 1..num_aggregators {
            // These `unwrap`s cannot fail. `random` was checked above to be exactly
            // `random_size()` bytes, which covers every seed taken in this function, and
            // `chunks_exact` only yields `SEED_SIZE`-byte slices, so the array conversion
            // always succeeds.
            let measurement_share_seed = random_seeds.next().unwrap().try_into().unwrap();
            let proof_share_seed = random_seeds.next().unwrap().try_into().unwrap();
            let measurement_share_prng: Prng<T::Field, _> = Prng::from_seed_stream(P::seed_stream(
                &Seed(measurement_share_seed),
                &self.domain_separation_tag(DST_MEASUREMENT_SHARE),
                &[agg_id],
            ));
            let joint_rand_blind = if let Some(helper_joint_rand_parts) =
                helper_joint_rand_parts.as_mut()
            {
                // This helper's joint randomness part binds its blind, its aggregator ID,
                // the nonce, and the encoding of its measurement share.
                let joint_rand_blind = random_seeds.next().unwrap().try_into().unwrap();
                let mut joint_rand_part_xof = P::init(
                    &joint_rand_blind,
                    &self.domain_separation_tag(DST_JOINT_RAND_PART),
                );
                joint_rand_part_xof.update(&[agg_id]); // Aggregator ID
                joint_rand_part_xof.update(nonce);

                let mut encoding_buffer = Vec::with_capacity(T::Field::ENCODED_SIZE);
                for (x, y) in leader_measurement_share
                    .iter_mut()
                    .zip(measurement_share_prng)
                {
                    *x -= y;
                    y.encode(&mut encoding_buffer).map_err(|_| {
                        VdafError::Uncategorized("failed to encode measurement share".to_string())
                    })?;
                    joint_rand_part_xof.update(&encoding_buffer);
                    encoding_buffer.clear();
                }

                helper_joint_rand_parts.push(joint_rand_part_xof.into_seed());

                Some(joint_rand_blind)
            } else {
                for (x, y) in leader_measurement_share
                    .iter_mut()
                    .zip(measurement_share_prng)
                {
                    *x -= y;
                }
                None
            };
            let helper =
                HelperShare::from_seeds(measurement_share_seed, proof_share_seed, joint_rand_blind);
            helper_shares.push(helper);
        }

        // Compute the leader's joint randomness part, if joint randomness is used. The
        // leader's blind is recorded in `leader_blind_opt` as a side effect of the closure.
        // The public share lists the leader's part first, then the helpers' parts in order.
        let mut leader_blind_opt = None;
        let public_share = Prio3PublicShare {
            joint_rand_parts: helper_joint_rand_parts
                .as_ref()
                .map(
                    |helper_joint_rand_parts| -> Result<Vec<Seed<SEED_SIZE>>, VdafError> {
                        let leader_blind_bytes = random_seeds.next().unwrap().try_into().unwrap();
                        let leader_blind = Seed::from_bytes(leader_blind_bytes);

                        let mut joint_rand_part_xof = P::init(
                            leader_blind.as_ref(),
                            &self.domain_separation_tag(DST_JOINT_RAND_PART),
                        );
                        joint_rand_part_xof.update(&[0]); // Aggregator ID
                        joint_rand_part_xof.update(nonce);
                        let mut encoding_buffer = Vec::with_capacity(T::Field::ENCODED_SIZE);
                        for x in leader_measurement_share.iter() {
                            x.encode(&mut encoding_buffer).map_err(|_| {
                                VdafError::Uncategorized(
                                    "failed to encode measurement share".to_string(),
                                )
                            })?;
                            joint_rand_part_xof.update(&encoding_buffer);
                            encoding_buffer.clear();
                        }
                        leader_blind_opt = Some(leader_blind);

                        let leader_joint_rand_seed_part = joint_rand_part_xof.into_seed();

                        let mut vec = Vec::with_capacity(self.num_aggregators());
                        vec.push(leader_joint_rand_seed_part);
                        vec.extend(helper_joint_rand_parts.iter().cloned());
                        Ok(vec)
                    },
                )
                .transpose()?,
        };

        // Expand the joint randomness from all parts; empty when the type does not use it.
        let joint_rands = public_share
            .joint_rand_parts
            .as_ref()
            .map(|joint_rand_parts| self.derive_joint_rands(joint_rand_parts.iter()).1)
            .unwrap_or_default();

        // Generate every proof over the full encoded measurement. Proof `p` uses the `p`-th
        // slice of the prover and joint randomness.
        let prove_rands = self.derive_prove_rands(&Seed::from_bytes(
            random_seeds.next().unwrap().try_into().unwrap(),
        ));
        let mut leader_proofs_share = Vec::with_capacity(self.typ.proof_len() * self.num_proofs());
        for p in 0..self.num_proofs() {
            let prove_rand =
                &prove_rands[p * self.typ.prove_rand_len()..(p + 1) * self.typ.prove_rand_len()];
            let joint_rand =
                &joint_rands[p * self.typ.joint_rand_len()..(p + 1) * self.typ.joint_rand_len()];

            leader_proofs_share.append(&mut self.typ.prove(
                &encoded_measurement,
                prove_rand,
                joint_rand,
            )?);
        }

        // Subtract each helper's seed-expanded proofs share from the leader's copy, so the
        // proofs shares of all aggregators sum to the generated proofs. Helper `j` in this
        // vector has aggregator ID `j + 1`.
        for (j, helper) in helper_shares.iter_mut().enumerate() {
            for (x, y) in
                leader_proofs_share
                    .iter_mut()
                    .zip(self.derive_helper_proofs_share(
                        &helper.proofs_share,
                        u8::try_from(j).unwrap() + 1,
                    ))
                    .take(self.typ.proof_len() * self.num_proofs())
            {
                *x -= y;
            }
        }

        // Assemble the input shares: the leader first, then the helpers in ID order.
        let mut out = Vec::with_capacity(num_aggregators as usize);
        out.push(Prio3InputShare {
            measurement_share: Share::Leader(leader_measurement_share),
            proofs_share: Share::Leader(leader_proofs_share),
            joint_rand_blind: leader_blind_opt,
        });

        for helper in helper_shares.into_iter() {
            out.push(Prio3InputShare {
                measurement_share: Share::Helper(helper.measurement_share),
                proofs_share: Share::Helper(helper.proofs_share),
                joint_rand_blind: helper.joint_rand_blind,
            });
        }

        Ok((public_share, out))
    }

    /// Checks that `agg_id` names one of this instance's aggregators and converts it to
    /// `u8`. The `unwrap` cannot fail because `num_aggregators` is itself a `u8`.
    fn role_try_from(&self, agg_id: usize) -> Result<u8, VdafError> {
        if agg_id >= self.num_aggregators as usize {
            return Err(VdafError::Uncategorized("unexpected aggregator id".into()));
        }
        Ok(u8::try_from(agg_id).unwrap())
    }
}
696
697impl<T, P, const SEED_SIZE: usize> Vdaf for Prio3<T, P, SEED_SIZE>
698where
699 T: Type,
700 P: Xof<SEED_SIZE>,
701{
702 type Measurement = T::Measurement;
703 type AggregateResult = T::AggregateResult;
704 type AggregationParam = ();
705 type PublicShare = Prio3PublicShare<SEED_SIZE>;
706 type InputShare = Prio3InputShare<T::Field, SEED_SIZE>;
707 type OutputShare = OutputShare<T::Field>;
708 type AggregateShare = AggregateShare<T::Field>;
709
710 fn algorithm_id(&self) -> u32 {
711 self.algorithm_id
712 }
713
714 fn num_aggregators(&self) -> usize {
715 self.num_aggregators as usize
716 }
717}
718
/// The public share produced by `Prio3::shard_with_random` and consumed by `prepare_init`.
#[derive(Clone, Debug)]
pub struct Prio3PublicShare<const SEED_SIZE: usize> {
    /// One joint randomness part per aggregator, the leader's first. Present only when the
    /// FLP type uses joint randomness.
    joint_rand_parts: Option<Vec<Seed<SEED_SIZE>>>,
}
725
726impl<const SEED_SIZE: usize> Encode for Prio3PublicShare<SEED_SIZE> {
727 fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
728 if let Some(joint_rand_parts) = self.joint_rand_parts.as_ref() {
729 for part in joint_rand_parts.iter() {
730 part.encode(bytes)?;
731 }
732 }
733 Ok(())
734 }
735
736 fn encoded_len(&self) -> Option<usize> {
737 if let Some(joint_rand_parts) = self.joint_rand_parts.as_ref() {
738 Some(SEED_SIZE * joint_rand_parts.len())
740 } else {
741 Some(0)
742 }
743 }
744}
745
746impl<const SEED_SIZE: usize> PartialEq for Prio3PublicShare<SEED_SIZE> {
747 fn eq(&self, other: &Self) -> bool {
748 self.ct_eq(other).into()
749 }
750}
751
752impl<const SEED_SIZE: usize> Eq for Prio3PublicShare<SEED_SIZE> {}
753
754impl<const SEED_SIZE: usize> ConstantTimeEq for Prio3PublicShare<SEED_SIZE> {
755 fn ct_eq(&self, other: &Self) -> Choice {
756 option_ct_eq(
758 self.joint_rand_parts.as_deref(),
759 other.joint_rand_parts.as_deref(),
760 )
761 }
762}
763
764impl<T, P, const SEED_SIZE: usize> ParameterizedDecode<Prio3<T, P, SEED_SIZE>>
765 for Prio3PublicShare<SEED_SIZE>
766where
767 T: Type,
768 P: Xof<SEED_SIZE>,
769{
770 fn decode_with_param(
771 decoding_parameter: &Prio3<T, P, SEED_SIZE>,
772 bytes: &mut Cursor<&[u8]>,
773 ) -> Result<Self, CodecError> {
774 if decoding_parameter.typ.joint_rand_len() > 0 {
775 let joint_rand_parts = iter::repeat_with(|| Seed::<SEED_SIZE>::decode(bytes))
776 .take(decoding_parameter.num_aggregators.into())
777 .collect::<Result<Vec<_>, _>>()?;
778 Ok(Self {
779 joint_rand_parts: Some(joint_rand_parts),
780 })
781 } else {
782 Ok(Self {
783 joint_rand_parts: None,
784 })
785 }
786 }
787}
788
/// The input share for a single aggregator, as produced by `Prio3::shard_with_random`.
#[derive(Clone, Debug)]
pub struct Prio3InputShare<F, const SEED_SIZE: usize> {
    /// The measurement share: an explicit vector for the leader, a seed for helpers.
    measurement_share: Share<F, SEED_SIZE>,

    /// The proofs share: an explicit vector for the leader, a seed for helpers.
    proofs_share: Share<F, SEED_SIZE>,

    /// Blind used to compute this aggregator's joint randomness part. Present only when the
    /// FLP type uses joint randomness.
    joint_rand_blind: Option<Seed<SEED_SIZE>>,
}
802
803impl<F: ConstantTimeEq, const SEED_SIZE: usize> PartialEq for Prio3InputShare<F, SEED_SIZE> {
804 fn eq(&self, other: &Self) -> bool {
805 self.ct_eq(other).into()
806 }
807}
808
809impl<F: ConstantTimeEq, const SEED_SIZE: usize> Eq for Prio3InputShare<F, SEED_SIZE> {}
810
811impl<F: ConstantTimeEq, const SEED_SIZE: usize> ConstantTimeEq for Prio3InputShare<F, SEED_SIZE> {
812 fn ct_eq(&self, other: &Self) -> Choice {
813 option_ct_eq(
815 self.joint_rand_blind.as_ref(),
816 other.joint_rand_blind.as_ref(),
817 ) & self.measurement_share.ct_eq(&other.measurement_share)
818 & self.proofs_share.ct_eq(&other.proofs_share)
819 }
820}
821
822impl<F: FftFriendlyFieldElement, const SEED_SIZE: usize> Encode for Prio3InputShare<F, SEED_SIZE> {
823 fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
824 if matches!(
825 (&self.measurement_share, &self.proofs_share),
826 (Share::Leader(_), Share::Helper(_)) | (Share::Helper(_), Share::Leader(_))
827 ) {
828 panic!("tried to encode input share with ambiguous encoding")
829 }
830
831 self.measurement_share.encode(bytes)?;
832 self.proofs_share.encode(bytes)?;
833 if let Some(ref blind) = self.joint_rand_blind {
834 blind.encode(bytes)?;
835 }
836 Ok(())
837 }
838
839 fn encoded_len(&self) -> Option<usize> {
840 let mut len = self.measurement_share.encoded_len()? + self.proofs_share.encoded_len()?;
841 if let Some(ref blind) = self.joint_rand_blind {
842 len += blind.encoded_len()?;
843 }
844 Some(len)
845 }
846}
847
848impl<'a, T, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Prio3<T, P, SEED_SIZE>, usize)>
849 for Prio3InputShare<T::Field, SEED_SIZE>
850where
851 T: Type,
852 P: Xof<SEED_SIZE>,
853{
854 fn decode_with_param(
855 (prio3, agg_id): &(&'a Prio3<T, P, SEED_SIZE>, usize),
856 bytes: &mut Cursor<&[u8]>,
857 ) -> Result<Self, CodecError> {
858 let agg_id = prio3
859 .role_try_from(*agg_id)
860 .map_err(|e| CodecError::Other(Box::new(e)))?;
861 let (input_decoder, proof_decoder) = if agg_id == 0 {
862 (
863 ShareDecodingParameter::Leader(prio3.typ.input_len()),
864 ShareDecodingParameter::Leader(prio3.typ.proof_len() * prio3.num_proofs()),
865 )
866 } else {
867 (
868 ShareDecodingParameter::Helper,
869 ShareDecodingParameter::Helper,
870 )
871 };
872
873 let measurement_share = Share::decode_with_param(&input_decoder, bytes)?;
874 let proofs_share = Share::decode_with_param(&proof_decoder, bytes)?;
875 let joint_rand_blind = if prio3.typ.joint_rand_len() > 0 {
876 let blind = Seed::decode(bytes)?;
877 Some(blind)
878 } else {
879 None
880 };
881
882 Ok(Prio3InputShare {
883 measurement_share,
884 proofs_share,
885 joint_rand_blind,
886 })
887 }
888}
889
/// The prepare share an aggregator emits from `prepare_init`. The shares of all
/// aggregators are combined by `prepare_shares_to_prepare_message`.
#[derive(Clone, Debug)]
pub struct Prio3PrepareShare<F, const SEED_SIZE: usize> {
    /// This aggregator's share of the verifiers: `verifier_len() * num_proofs` elements.
    verifiers: Vec<F>,

    /// This aggregator's joint randomness part, present when the type uses joint
    /// randomness.
    joint_rand_part: Option<Seed<SEED_SIZE>>,
}
899
900impl<F: ConstantTimeEq, const SEED_SIZE: usize> PartialEq for Prio3PrepareShare<F, SEED_SIZE> {
901 fn eq(&self, other: &Self) -> bool {
902 self.ct_eq(other).into()
903 }
904}
905
906impl<F: ConstantTimeEq, const SEED_SIZE: usize> Eq for Prio3PrepareShare<F, SEED_SIZE> {}
907
908impl<F: ConstantTimeEq, const SEED_SIZE: usize> ConstantTimeEq for Prio3PrepareShare<F, SEED_SIZE> {
909 fn ct_eq(&self, other: &Self) -> Choice {
910 option_ct_eq(
912 self.joint_rand_part.as_ref(),
913 other.joint_rand_part.as_ref(),
914 ) & self.verifiers.ct_eq(&other.verifiers)
915 }
916}
917
918impl<F: FftFriendlyFieldElement, const SEED_SIZE: usize> Encode
919 for Prio3PrepareShare<F, SEED_SIZE>
920{
921 fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
922 for x in &self.verifiers {
923 x.encode(bytes)?;
924 }
925 if let Some(ref seed) = self.joint_rand_part {
926 seed.encode(bytes)?;
927 }
928 Ok(())
929 }
930
931 fn encoded_len(&self) -> Option<usize> {
932 let mut len = F::ENCODED_SIZE * self.verifiers.len();
934 if let Some(ref seed) = self.joint_rand_part {
935 len += seed.encoded_len()?;
936 }
937 Some(len)
938 }
939}
940
941impl<F: FftFriendlyFieldElement, const SEED_SIZE: usize>
942 ParameterizedDecode<Prio3PrepareState<F, SEED_SIZE>> for Prio3PrepareShare<F, SEED_SIZE>
943{
944 fn decode_with_param(
945 decoding_parameter: &Prio3PrepareState<F, SEED_SIZE>,
946 bytes: &mut Cursor<&[u8]>,
947 ) -> Result<Self, CodecError> {
948 let mut verifiers = Vec::with_capacity(decoding_parameter.verifiers_len);
949 for _ in 0..decoding_parameter.verifiers_len {
950 verifiers.push(F::decode(bytes)?);
951 }
952
953 let joint_rand_part = if decoding_parameter.joint_rand_seed.is_some() {
954 Some(Seed::decode(bytes)?)
955 } else {
956 None
957 };
958
959 Ok(Prio3PrepareShare {
960 verifiers,
961 joint_rand_part,
962 })
963 }
964}
965
/// The result of combining all aggregators' prepare shares.
#[derive(Clone, Debug)]
pub struct Prio3PrepareMessage<const SEED_SIZE: usize> {
    /// Joint randomness seed derived from every aggregator's part. `prepare_next` compares
    /// it against the locally computed seed. Present only when the type uses joint
    /// randomness.
    joint_rand_seed: Option<Seed<SEED_SIZE>>,
}
972
973impl<const SEED_SIZE: usize> PartialEq for Prio3PrepareMessage<SEED_SIZE> {
974 fn eq(&self, other: &Self) -> bool {
975 self.ct_eq(other).into()
976 }
977}
978
979impl<const SEED_SIZE: usize> Eq for Prio3PrepareMessage<SEED_SIZE> {}
980
981impl<const SEED_SIZE: usize> ConstantTimeEq for Prio3PrepareMessage<SEED_SIZE> {
982 fn ct_eq(&self, other: &Self) -> Choice {
983 option_ct_eq(
985 self.joint_rand_seed.as_ref(),
986 other.joint_rand_seed.as_ref(),
987 )
988 }
989}
990
991impl<const SEED_SIZE: usize> Encode for Prio3PrepareMessage<SEED_SIZE> {
992 fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
993 if let Some(ref seed) = self.joint_rand_seed {
994 seed.encode(bytes)?;
995 }
996 Ok(())
997 }
998
999 fn encoded_len(&self) -> Option<usize> {
1000 if let Some(ref seed) = self.joint_rand_seed {
1001 seed.encoded_len()
1002 } else {
1003 Some(0)
1004 }
1005 }
1006}
1007
1008impl<F: FftFriendlyFieldElement, const SEED_SIZE: usize>
1009 ParameterizedDecode<Prio3PrepareState<F, SEED_SIZE>> for Prio3PrepareMessage<SEED_SIZE>
1010{
1011 fn decode_with_param(
1012 decoding_parameter: &Prio3PrepareState<F, SEED_SIZE>,
1013 bytes: &mut Cursor<&[u8]>,
1014 ) -> Result<Self, CodecError> {
1015 let joint_rand_seed = if decoding_parameter.joint_rand_seed.is_some() {
1016 Some(Seed::decode(bytes)?)
1017 } else {
1018 None
1019 };
1020
1021 Ok(Prio3PrepareMessage { joint_rand_seed })
1022 }
1023}
1024
1025impl<T, P, const SEED_SIZE: usize> Client<16> for Prio3<T, P, SEED_SIZE>
1026where
1027 T: Type,
1028 P: Xof<SEED_SIZE>,
1029{
1030 #[allow(clippy::type_complexity)]
1031 fn shard(
1032 &self,
1033 measurement: &T::Measurement,
1034 nonce: &[u8; 16],
1035 ) -> Result<(Self::PublicShare, Vec<Prio3InputShare<T::Field, SEED_SIZE>>), VdafError> {
1036 let mut random = vec![0u8; self.random_size()];
1037 getrandom::getrandom(&mut random)?;
1038 self.shard_with_random(measurement, nonce, &random)
1039 }
1040}
1041
/// Per-aggregator state carried from `prepare_init` to `prepare_next`.
///
/// `Debug` is implemented by hand so that secret-bearing fields are redacted.
#[derive(Clone)]
pub struct Prio3PrepareState<F, const SEED_SIZE: usize> {
    /// The measurement share as received (a leader vector or a helper seed).
    measurement_share: Share<F, SEED_SIZE>,
    /// Locally computed joint randomness seed, present when the type uses joint
    /// randomness.
    joint_rand_seed: Option<Seed<SEED_SIZE>>,
    /// This aggregator's index; 0 is the leader.
    agg_id: u8,
    /// Expected number of verifier elements in a prepare share; used when decoding them.
    verifiers_len: usize,
}
1050
1051impl<F: ConstantTimeEq, const SEED_SIZE: usize> PartialEq for Prio3PrepareState<F, SEED_SIZE> {
1052 fn eq(&self, other: &Self) -> bool {
1053 self.ct_eq(other).into()
1054 }
1055}
1056
1057impl<F: ConstantTimeEq, const SEED_SIZE: usize> Eq for Prio3PrepareState<F, SEED_SIZE> {}
1058
1059impl<F: ConstantTimeEq, const SEED_SIZE: usize> ConstantTimeEq for Prio3PrepareState<F, SEED_SIZE> {
1060 fn ct_eq(&self, other: &Self) -> Choice {
1061 if self.agg_id != other.agg_id || self.verifiers_len != other.verifiers_len {
1064 return Choice::from(0);
1065 }
1066
1067 option_ct_eq(
1068 self.joint_rand_seed.as_ref(),
1069 other.joint_rand_seed.as_ref(),
1070 ) & self.measurement_share.ct_eq(&other.measurement_share)
1071 }
1072}
1073
1074impl<F, const SEED_SIZE: usize> Debug for Prio3PrepareState<F, SEED_SIZE> {
1075 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1076 f.debug_struct("Prio3PrepareState")
1077 .field("measurement_share", &"[redacted]")
1078 .field(
1079 "joint_rand_seed",
1080 match self.joint_rand_seed {
1081 Some(_) => &"Some([redacted])",
1082 None => &"None",
1083 },
1084 )
1085 .field("agg_id", &self.agg_id)
1086 .field("verifiers_len", &self.verifiers_len)
1087 .finish()
1088 }
1089}
1090
1091impl<F: FftFriendlyFieldElement, const SEED_SIZE: usize> Encode
1092 for Prio3PrepareState<F, SEED_SIZE>
1093{
1094 fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
1096 self.measurement_share.encode(bytes)?;
1097 if let Some(ref seed) = self.joint_rand_seed {
1098 seed.encode(bytes)?;
1099 }
1100 Ok(())
1101 }
1102
1103 fn encoded_len(&self) -> Option<usize> {
1104 let mut len = self.measurement_share.encoded_len()?;
1105 if let Some(ref seed) = self.joint_rand_seed {
1106 len += seed.encoded_len()?;
1107 }
1108 Some(len)
1109 }
1110}
1111
1112impl<'a, T, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Prio3<T, P, SEED_SIZE>, usize)>
1113 for Prio3PrepareState<T::Field, SEED_SIZE>
1114where
1115 T: Type,
1116 P: Xof<SEED_SIZE>,
1117{
1118 fn decode_with_param(
1119 (prio3, agg_id): &(&'a Prio3<T, P, SEED_SIZE>, usize),
1120 bytes: &mut Cursor<&[u8]>,
1121 ) -> Result<Self, CodecError> {
1122 let agg_id = prio3
1123 .role_try_from(*agg_id)
1124 .map_err(|e| CodecError::Other(Box::new(e)))?;
1125
1126 let share_decoder = if agg_id == 0 {
1127 ShareDecodingParameter::Leader(prio3.typ.input_len())
1128 } else {
1129 ShareDecodingParameter::Helper
1130 };
1131 let measurement_share = Share::decode_with_param(&share_decoder, bytes)?;
1132
1133 let joint_rand_seed = if prio3.typ.joint_rand_len() > 0 {
1134 Some(Seed::decode(bytes)?)
1135 } else {
1136 None
1137 };
1138
1139 Ok(Self {
1140 measurement_share,
1141 joint_rand_seed,
1142 agg_id,
1143 verifiers_len: prio3.typ.verifier_len() * prio3.num_proofs(),
1144 })
1145 }
1146}
1147
1148impl<T, P, const SEED_SIZE: usize> Aggregator<SEED_SIZE, 16> for Prio3<T, P, SEED_SIZE>
1149where
1150 T: Type,
1151 P: Xof<SEED_SIZE>,
1152{
1153 type PrepareState = Prio3PrepareState<T::Field, SEED_SIZE>;
1154 type PrepareShare = Prio3PrepareShare<T::Field, SEED_SIZE>;
1155 type PrepareMessage = Prio3PrepareMessage<SEED_SIZE>;
1156
1157 #[allow(clippy::type_complexity)]
1160 fn prepare_init(
1161 &self,
1162 verify_key: &[u8; SEED_SIZE],
1163 agg_id: usize,
1164 _agg_param: &Self::AggregationParam,
1165 nonce: &[u8; 16],
1166 public_share: &Self::PublicShare,
1167 msg: &Prio3InputShare<T::Field, SEED_SIZE>,
1168 ) -> Result<
1169 (
1170 Prio3PrepareState<T::Field, SEED_SIZE>,
1171 Prio3PrepareShare<T::Field, SEED_SIZE>,
1172 ),
1173 VdafError,
1174 > {
1175 let agg_id = self.role_try_from(agg_id)?;
1176
1177 let measurement_share = match msg.measurement_share {
1178 Share::Leader(ref data) => Cow::Borrowed(data),
1179 Share::Helper(ref seed) => Cow::Owned(
1180 P::seed_stream(
1181 seed,
1182 &self.domain_separation_tag(DST_MEASUREMENT_SHARE),
1183 &[agg_id],
1184 )
1185 .into_field_vec(self.typ.input_len()),
1186 ),
1187 };
1188
1189 let proofs_share = match msg.proofs_share {
1190 Share::Leader(ref data) => Cow::Borrowed(data),
1191 Share::Helper(ref seed) => Cow::Owned(
1192 self.derive_helper_proofs_share(seed, agg_id)
1193 .take(self.typ.proof_len() * self.num_proofs())
1194 .collect::<Vec<_>>(),
1195 ),
1196 };
1197
1198 let (joint_rand_seed, joint_rand_part, joint_rands) = if self.typ.joint_rand_len() > 0 {
1200 let mut joint_rand_part_xof = P::init(
1201 msg.joint_rand_blind.as_ref().unwrap().as_ref(),
1202 &self.domain_separation_tag(DST_JOINT_RAND_PART),
1203 );
1204 joint_rand_part_xof.update(&[agg_id]);
1205 joint_rand_part_xof.update(nonce);
1206 let mut encoding_buffer = Vec::with_capacity(T::Field::ENCODED_SIZE);
1207 for x in measurement_share.iter() {
1208 x.encode(&mut encoding_buffer).map_err(|_| {
1209 VdafError::Uncategorized("failed to encode measurement share".to_string())
1210 })?;
1211 joint_rand_part_xof.update(&encoding_buffer);
1212 encoding_buffer.clear();
1213 }
1214 let own_joint_rand_part = joint_rand_part_xof.into_seed();
1215
1216 let corrected_joint_rand_parts = public_share
1224 .joint_rand_parts
1225 .iter()
1226 .flatten()
1227 .take(agg_id as usize)
1228 .chain(iter::once(&own_joint_rand_part))
1229 .chain(
1230 public_share
1231 .joint_rand_parts
1232 .iter()
1233 .flatten()
1234 .skip(agg_id as usize + 1),
1235 );
1236
1237 let (joint_rand_seed, joint_rands) =
1238 self.derive_joint_rands(corrected_joint_rand_parts);
1239
1240 (
1241 Some(joint_rand_seed),
1242 Some(own_joint_rand_part),
1243 joint_rands,
1244 )
1245 } else {
1246 (None, None, Vec::new())
1247 };
1248
1249 let query_rands = self.derive_query_rands(verify_key, nonce);
1251 let mut verifiers_share = Vec::with_capacity(self.typ.verifier_len() * self.num_proofs());
1252 for p in 0..self.num_proofs() {
1253 let query_rand =
1254 &query_rands[p * self.typ.query_rand_len()..(p + 1) * self.typ.query_rand_len()];
1255 let joint_rand =
1256 &joint_rands[p * self.typ.joint_rand_len()..(p + 1) * self.typ.joint_rand_len()];
1257 let proof_share =
1258 &proofs_share[p * self.typ.proof_len()..(p + 1) * self.typ.proof_len()];
1259
1260 verifiers_share.append(&mut self.typ.query(
1261 measurement_share.as_ref(),
1262 proof_share,
1263 query_rand,
1264 joint_rand,
1265 self.num_aggregators as usize,
1266 )?);
1267 }
1268
1269 Ok((
1270 Prio3PrepareState {
1271 measurement_share: msg.measurement_share.clone(),
1272 joint_rand_seed,
1273 agg_id,
1274 verifiers_len: verifiers_share.len(),
1275 },
1276 Prio3PrepareShare {
1277 verifiers: verifiers_share,
1278 joint_rand_part,
1279 },
1280 ))
1281 }
1282
1283 fn prepare_shares_to_prepare_message<
1284 M: IntoIterator<Item = Prio3PrepareShare<T::Field, SEED_SIZE>>,
1285 >(
1286 &self,
1287 _: &Self::AggregationParam,
1288 inputs: M,
1289 ) -> Result<Prio3PrepareMessage<SEED_SIZE>, VdafError> {
1290 let mut verifiers = vec![T::Field::zero(); self.typ.verifier_len() * self.num_proofs()];
1291 let mut joint_rand_parts = Vec::with_capacity(self.num_aggregators());
1292 let mut count = 0;
1293 for share in inputs.into_iter() {
1294 count += 1;
1295
1296 if share.verifiers.len() != verifiers.len() {
1297 return Err(VdafError::Uncategorized(format!(
1298 "unexpected verifier share length: got {}; want {}",
1299 share.verifiers.len(),
1300 verifiers.len(),
1301 )));
1302 }
1303
1304 if self.typ.joint_rand_len() > 0 {
1305 let joint_rand_seed_part = share.joint_rand_part.unwrap();
1306 joint_rand_parts.push(joint_rand_seed_part);
1307 }
1308
1309 for (x, y) in verifiers.iter_mut().zip(share.verifiers) {
1310 *x += y;
1311 }
1312 }
1313
1314 if count != self.num_aggregators {
1315 return Err(VdafError::Uncategorized(format!(
1316 "unexpected message count: got {}; want {}",
1317 count, self.num_aggregators,
1318 )));
1319 }
1320
1321 for verifier in verifiers.chunks(self.typ.verifier_len()) {
1323 if !self.typ.decide(verifier)? {
1324 return Err(VdafError::Uncategorized(
1325 "proof verifier check failed".into(),
1326 ));
1327 }
1328 }
1329
1330 let joint_rand_seed = if self.typ.joint_rand_len() > 0 {
1331 Some(self.derive_joint_rand_seed(joint_rand_parts.iter()))
1332 } else {
1333 None
1334 };
1335
1336 Ok(Prio3PrepareMessage { joint_rand_seed })
1337 }
1338
1339 fn prepare_next(
1340 &self,
1341 step: Prio3PrepareState<T::Field, SEED_SIZE>,
1342 msg: Prio3PrepareMessage<SEED_SIZE>,
1343 ) -> Result<PrepareTransition<Self, SEED_SIZE, 16>, VdafError> {
1344 if self.typ.joint_rand_len() > 0 {
1345 if step
1347 .joint_rand_seed
1348 .as_ref()
1349 .unwrap()
1350 .ct_ne(msg.joint_rand_seed.as_ref().unwrap())
1351 .into()
1352 {
1353 return Err(VdafError::Uncategorized(
1354 "joint randomness mismatch".to_string(),
1355 ));
1356 }
1357 }
1358
1359 let measurement_share = match step.measurement_share {
1361 Share::Leader(data) => data,
1362 Share::Helper(seed) => {
1363 let dst = self.domain_separation_tag(DST_MEASUREMENT_SHARE);
1364 P::seed_stream(&seed, &dst, &[step.agg_id]).into_field_vec(self.typ.input_len())
1365 }
1366 };
1367
1368 let output_share = match self.typ.truncate(measurement_share) {
1369 Ok(data) => OutputShare(data),
1370 Err(err) => {
1371 return Err(VdafError::from(err));
1372 }
1373 };
1374
1375 Ok(PrepareTransition::Finish(output_share))
1376 }
1377
1378 fn aggregate<It: IntoIterator<Item = OutputShare<T::Field>>>(
1380 &self,
1381 _agg_param: &(),
1382 output_shares: It,
1383 ) -> Result<AggregateShare<T::Field>, VdafError> {
1384 let mut agg_share = AggregateShare(vec![T::Field::zero(); self.typ.output_len()]);
1385 for output_share in output_shares.into_iter() {
1386 agg_share.accumulate(&output_share)?;
1387 }
1388
1389 Ok(agg_share)
1390 }
1391}
1392
#[cfg(feature = "experimental")]
impl<T, P, S, const SEED_SIZE: usize> AggregatorWithNoise<SEED_SIZE, 16, S>
    for Prio3<T, P, SEED_SIZE>
where
    T: TypeWithNoise<S>,
    P: Xof<SEED_SIZE>,
    S: DifferentialPrivacyStrategy,
{
    /// Adds differential-privacy noise to an aggregate share in place.
    ///
    /// The work is delegated to the FLP type's `add_noise_to_result`, which receives the
    /// share's field elements, the strategy, and the number of measurements. Its error is
    /// converted into a [`VdafError`].
    fn add_noise_to_agg_share(
        &self,
        dp_strategy: &S,
        _agg_param: &Self::AggregationParam,
        agg_share: &mut Self::AggregateShare,
        num_measurements: usize,
    ) -> Result<(), VdafError> {
        let share_elements = &mut agg_share.0;
        self.typ
            .add_noise_to_result(dp_strategy, share_elements, num_measurements)?;
        Ok(())
    }
}
1413
1414impl<T, P, const SEED_SIZE: usize> Collector for Prio3<T, P, SEED_SIZE>
1415where
1416 T: Type,
1417 P: Xof<SEED_SIZE>,
1418{
1419 fn unshard<It: IntoIterator<Item = AggregateShare<T::Field>>>(
1421 &self,
1422 _agg_param: &Self::AggregationParam,
1423 agg_shares: It,
1424 num_measurements: usize,
1425 ) -> Result<T::AggregateResult, VdafError> {
1426 let mut agg = AggregateShare(vec![T::Field::zero(); self.typ.output_len()]);
1427 for agg_share in agg_shares.into_iter() {
1428 agg.merge(&agg_share)?;
1429 }
1430
1431 Ok(self.typ.decode_result(&agg.0, num_measurements)?)
1432 }
1433}
1434
1435#[derive(Clone)]
1436struct HelperShare<const SEED_SIZE: usize> {
1437 measurement_share: Seed<SEED_SIZE>,
1438 proofs_share: Seed<SEED_SIZE>,
1439 joint_rand_blind: Option<Seed<SEED_SIZE>>,
1440}
1441
1442impl<const SEED_SIZE: usize> HelperShare<SEED_SIZE> {
1443 fn from_seeds(
1444 measurement_share: [u8; SEED_SIZE],
1445 proof_share: [u8; SEED_SIZE],
1446 joint_rand_blind: Option<[u8; SEED_SIZE]>,
1447 ) -> Self {
1448 HelperShare {
1449 measurement_share: Seed::from_bytes(measurement_share),
1450 proofs_share: Seed::from_bytes(proof_share),
1451 joint_rand_blind: joint_rand_blind.map(Seed::from_bytes),
1452 }
1453 }
1454}
1455
1456fn check_num_aggregators(num_aggregators: u8) -> Result<(), VdafError> {
1457 if num_aggregators == 0 {
1458 return Err(VdafError::Uncategorized(format!(
1459 "at least one aggregator is required; got {num_aggregators}"
1460 )));
1461 } else if num_aggregators > 254 {
1462 return Err(VdafError::Uncategorized(format!(
1463 "number of aggregators must not exceed 254; got {num_aggregators}"
1464 )));
1465 }
1466
1467 Ok(())
1468}
1469
1470impl<'a, F, T, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Prio3<T, P, SEED_SIZE>, &'a ())>
1471 for OutputShare<F>
1472where
1473 F: FieldElement,
1474 T: Type,
1475 P: Xof<SEED_SIZE>,
1476{
1477 fn decode_with_param(
1478 (vdaf, _): &(&'a Prio3<T, P, SEED_SIZE>, &'a ()),
1479 bytes: &mut Cursor<&[u8]>,
1480 ) -> Result<Self, CodecError> {
1481 decode_fieldvec(vdaf.output_len(), bytes).map(Self)
1482 }
1483}
1484
1485impl<'a, F, T, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Prio3<T, P, SEED_SIZE>, &'a ())>
1486 for AggregateShare<F>
1487where
1488 F: FieldElement,
1489 T: Type,
1490 P: Xof<SEED_SIZE>,
1491{
1492 fn decode_with_param(
1493 (vdaf, _): &(&'a Prio3<T, P, SEED_SIZE>, &'a ()),
1494 bytes: &mut Cursor<&[u8]>,
1495 ) -> Result<Self, CodecError> {
1496 decode_fieldvec(vdaf.output_len(), bytes).map(Self)
1497 }
1498}
1499
1500#[inline]
1504fn option_ct_eq<T>(left: Option<&T>, right: Option<&T>) -> Choice
1505where
1506 T: ConstantTimeEq + ?Sized,
1507{
1508 match (left, right) {
1509 (Some(left), Some(right)) => left.ct_eq(right),
1510 (None, None) => Choice::from(1),
1511 _ => Choice::from(0),
1512 }
1513}
1514
/// Returns ⌊log₂(`input`)⌋, i.e. the position of the most significant set bit.
///
/// # Panics
///
/// Panics if `input` is zero.
fn ilog2(input: usize) -> u32 {
    assert!(input != 0, "Tried to take the logarithm of zero");
    let significant_bits = usize::BITS - input.leading_zeros();
    significant_bits - 1
}
1528
1529pub fn optimal_chunk_length(measurement_length: usize) -> usize {
1534 if measurement_length <= 1 {
1535 return 1;
1536 }
1537
1538 struct Candidate {
1540 gadget_calls: usize,
1541 chunk_length: usize,
1542 }
1543
1544 let max_log2 = ilog2(measurement_length + 1);
1545 let best_opt = (1..=max_log2)
1546 .rev()
1547 .map(|log2| {
1548 let gadget_calls = (1 << log2) - 1;
1549 let chunk_length = (measurement_length + gadget_calls - 1) / gadget_calls;
1550 Candidate {
1551 gadget_calls,
1552 chunk_length,
1553 }
1554 })
1555 .min_by_key(|candidate| {
1556 (candidate.chunk_length * 2)
1558 + 2 * ((1 + candidate.gadget_calls).next_power_of_two() - 1)
1559 });
1560 best_opt.unwrap().chunk_length
1564}
1565
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "experimental")]
    use crate::flp::gadgets::ParallelSumGadget;
    use crate::vdaf::{
        equality_comparison_test, fieldvec_roundtrip_test,
        test_utils::{run_vdaf, run_vdaf_prepare},
    };
    use assert_matches::assert_matches;
    #[cfg(feature = "experimental")]
    use fixed::{
        types::extra::{U15, U31, U63},
        FixedI16, FixedI32, FixedI64,
    };
    #[cfg(feature = "experimental")]
    use fixed_macro::fixed;
    use rand::prelude::*;

    #[test]
    fn test_prio3_count() {
        let prio3 = Prio3::new_count(2).unwrap();

        assert_eq!(
            run_vdaf(&prio3, &(), [true, false, false, true, true]).unwrap(),
            3
        );

        let mut nonce = [0; 16];
        let mut verify_key = [0; 16];
        thread_rng().fill(&mut verify_key[..]);
        thread_rng().fill(&mut nonce[..]);

        let (public_share, input_shares) = prio3.shard(&false, &nonce).unwrap();
        run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares).unwrap();

        let (public_share, input_shares) = prio3.shard(&true, &nonce).unwrap();
        run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares).unwrap();

        test_serialization(&prio3, &true, &nonce).unwrap();

        let prio3_extra_helper = Prio3::new_count(3).unwrap();
        assert_eq!(
            run_vdaf(&prio3_extra_helper, &(), [true, false, false, true, true]).unwrap(),
            3,
        );
    }

    #[test]
    fn test_prio3_sum() {
        let prio3 = Prio3::new_sum(3, 16).unwrap();

        assert_eq!(
            run_vdaf(&prio3, &(), [0, (1 << 16) - 1, 0, 1, 1]).unwrap(),
            (1 << 16) + 1
        );

        let mut verify_key = [0; 16];
        thread_rng().fill(&mut verify_key[..]);
        let nonce = [0; 16];

        // Tampering with the leader's joint randomness blind must be detected.
        let (public_share, mut input_shares) = prio3.shard(&1, &nonce).unwrap();
        input_shares[0].joint_rand_blind.as_mut().unwrap().0[0] ^= 255;
        let result = run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares);
        assert_matches!(result, Err(VdafError::Uncategorized(_)));

        // Tampering with the leader's measurement share must be detected.
        let (public_share, mut input_shares) = prio3.shard(&1, &nonce).unwrap();
        assert_matches!(input_shares[0].measurement_share, Share::Leader(ref mut data) => {
            data[0] += Field128::one();
        });
        let result = run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares);
        assert_matches!(result, Err(VdafError::Uncategorized(_)));

        // Tampering with the leader's proof share must be detected.
        let (public_share, mut input_shares) = prio3.shard(&1, &nonce).unwrap();
        assert_matches!(input_shares[0].proofs_share, Share::Leader(ref mut data) => {
            data[0] += Field128::one();
        });
        let result = run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares);
        assert_matches!(result, Err(VdafError::Uncategorized(_)));

        test_serialization(&prio3, &1, &nonce).unwrap();
    }

    #[test]
    fn test_prio3_sum_vec() {
        let prio3 = Prio3::new_sum_vec(2, 2, 20, 4).unwrap();
        assert_eq!(
            run_vdaf(
                &prio3,
                &(),
                [
                    vec![0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1],
                    vec![0, 2, 0, 0, 1, 0, 0, 0, 1, 1, 1, 3, 0, 3, 0, 0, 0, 1, 0, 0],
                    vec![1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1],
                ]
            )
            .unwrap(),
            vec![1, 3, 1, 0, 3, 1, 0, 1, 2, 2, 3, 3, 1, 5, 1, 2, 1, 3, 0, 2],
        );
    }

    #[test]
    fn test_prio3_sum_vec_multiproof() {
        let prio3 = Prio3::<
            SumVec<Field128, ParallelSum<Field128, Mul<Field128>>>,
            XofTurboShake128,
            16,
        >::new(2, 2, 0xFFFF0000, SumVec::new(2, 20, 4).unwrap())
        .unwrap();

        assert_eq!(
            run_vdaf(
                &prio3,
                &(),
                [
                    vec![0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1],
                    vec![0, 2, 0, 0, 1, 0, 0, 0, 1, 1, 1, 3, 0, 3, 0, 0, 0, 1, 0, 0],
                    vec![1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1],
                ]
            )
            .unwrap(),
            vec![1, 3, 1, 0, 3, 1, 0, 1, 2, 2, 3, 3, 1, 5, 1, 2, 1, 3, 0, 2],
        );
    }

    #[test]
    #[cfg(feature = "multithreaded")]
    fn test_prio3_sum_vec_multithreaded() {
        let prio3 = Prio3::new_sum_vec_multithreaded(2, 2, 20, 4).unwrap();
        assert_eq!(
            run_vdaf(
                &prio3,
                &(),
                [
                    vec![0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1],
                    vec![0, 2, 0, 0, 1, 0, 0, 0, 1, 1, 1, 3, 0, 3, 0, 0, 0, 1, 0, 0],
                    vec![1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1],
                ]
            )
            .unwrap(),
            vec![1, 3, 1, 0, 3, 1, 0, 1, 2, 2, 3, 3, 1, 5, 1, 2, 1, 3, 0, 2],
        );
    }

    #[test]
    #[cfg(feature = "experimental")]
    fn test_prio3_bounded_fpvec_sum_unaligned() {
        type P<Fx> = Prio3FixedPointBoundedL2VecSum<Fx>;
        #[cfg(feature = "multithreaded")]
        type PM<Fx> = Prio3FixedPointBoundedL2VecSumMultithreaded<Fx>;
        let ctor_32 = P::<FixedI32<U31>>::new_fixedpoint_boundedl2_vec_sum;
        #[cfg(feature = "multithreaded")]
        let ctor_mt_32 = PM::<FixedI32<U31>>::new_fixedpoint_boundedl2_vec_sum_multithreaded;

        {
            const SIZE: usize = 5;
            let fp32_0 = fixed!(0: I1F31);

            // Single-threaded variant.
            {
                let prio3_32 = ctor_32(2, SIZE).unwrap();
                test_fixed_vec::<_, _, _, SIZE>(fp32_0, prio3_32);
            }

            // Multithreaded variant.
            #[cfg(feature = "multithreaded")]
            {
                let prio3_mt_32 = ctor_mt_32(2, SIZE).unwrap();
                test_fixed_vec::<_, _, _, SIZE>(fp32_0, prio3_mt_32);
            }
        }

        /// Aggregates two all-zero vectors of length `SIZE` and expects an all-zero result.
        fn test_fixed_vec<Fx, PE, M, const SIZE: usize>(
            fp_0: Fx,
            prio3: Prio3<FixedPointBoundedL2VecSum<Fx, PE, M>, XofTurboShake128, 16>,
        ) where
            Fx: Fixed + CompatibleFloat + std::ops::Neg<Output = Fx>,
            PE: Eq + ParallelSumGadget<Field128, PolyEval<Field128>> + Clone + 'static,
            M: Eq + ParallelSumGadget<Field128, Mul<Field128>> + Clone + 'static,
        {
            let fp_vec = vec![fp_0; SIZE];

            let measurements = [fp_vec.clone(), fp_vec];
            assert_eq!(
                run_vdaf(&prio3, &(), measurements).unwrap(),
                vec![0.0; SIZE]
            );
        }
    }

    #[test]
    #[cfg(feature = "experimental")]
    fn test_prio3_bounded_fpvec_sum() {
        type P<Fx> = Prio3FixedPointBoundedL2VecSum<Fx>;
        let ctor_16 = P::<FixedI16<U15>>::new_fixedpoint_boundedl2_vec_sum;
        let ctor_32 = P::<FixedI32<U31>>::new_fixedpoint_boundedl2_vec_sum;
        let ctor_64 = P::<FixedI64<U63>>::new_fixedpoint_boundedl2_vec_sum;

        #[cfg(feature = "multithreaded")]
        type PM<Fx> = Prio3FixedPointBoundedL2VecSumMultithreaded<Fx>;
        #[cfg(feature = "multithreaded")]
        let ctor_mt_16 = PM::<FixedI16<U15>>::new_fixedpoint_boundedl2_vec_sum_multithreaded;
        #[cfg(feature = "multithreaded")]
        let ctor_mt_32 = PM::<FixedI32<U31>>::new_fixedpoint_boundedl2_vec_sum_multithreaded;
        #[cfg(feature = "multithreaded")]
        let ctor_mt_64 = PM::<FixedI64<U63>>::new_fixedpoint_boundedl2_vec_sum_multithreaded;

        // 16-bit fixed point.
        {
            let fp16_4_inv = fixed!(0.25: I1F15);
            let fp16_8_inv = fixed!(0.125: I1F15);
            let fp16_16_inv = fixed!(0.0625: I1F15);

            {
                let prio3_16 = ctor_16(2, 3).unwrap();
                test_fixed(fp16_4_inv, fp16_8_inv, fp16_16_inv, prio3_16);
            }

            #[cfg(feature = "multithreaded")]
            {
                let prio3_16_mt = ctor_mt_16(2, 3).unwrap();
                test_fixed(fp16_4_inv, fp16_8_inv, fp16_16_inv, prio3_16_mt);
            }
        }

        // 32-bit fixed point.
        {
            let fp32_4_inv = fixed!(0.25: I1F31);
            let fp32_8_inv = fixed!(0.125: I1F31);
            let fp32_16_inv = fixed!(0.0625: I1F31);

            {
                let prio3_32 = ctor_32(2, 3).unwrap();
                test_fixed(fp32_4_inv, fp32_8_inv, fp32_16_inv, prio3_32);
            }

            #[cfg(feature = "multithreaded")]
            {
                let prio3_32_mt = ctor_mt_32(2, 3).unwrap();
                test_fixed(fp32_4_inv, fp32_8_inv, fp32_16_inv, prio3_32_mt);
            }
        }

        // 64-bit fixed point.
        {
            let fp64_4_inv = fixed!(0.25: I1F63);
            let fp64_8_inv = fixed!(0.125: I1F63);
            let fp64_16_inv = fixed!(0.0625: I1F63);

            {
                let prio3_64 = ctor_64(2, 3).unwrap();
                test_fixed(fp64_4_inv, fp64_8_inv, fp64_16_inv, prio3_64);
            }

            #[cfg(feature = "multithreaded")]
            {
                let prio3_64_mt = ctor_mt_64(2, 3).unwrap();
                test_fixed(fp64_4_inv, fp64_8_inv, fp64_16_inv, prio3_64_mt);
            }
        }

        /// Runs aggregation, tamper-detection and serialization checks for one fixed-point VDAF.
        fn test_fixed<Fx, PE, M>(
            fp_4_inv: Fx,
            fp_8_inv: Fx,
            fp_16_inv: Fx,
            prio3: Prio3<FixedPointBoundedL2VecSum<Fx, PE, M>, XofTurboShake128, 16>,
        ) where
            Fx: Fixed + CompatibleFloat + std::ops::Neg<Output = Fx>,
            PE: Eq + ParallelSumGadget<Field128, PolyEval<Field128>> + Clone + 'static,
            M: Eq + ParallelSumGadget<Field128, Mul<Field128>> + Clone + 'static,
        {
            let fp_vec1 = vec![fp_4_inv, fp_8_inv, fp_16_inv];
            let fp_vec2 = vec![fp_4_inv, fp_8_inv, fp_16_inv];

            let fp_vec3 = vec![-fp_4_inv, -fp_8_inv, -fp_16_inv];
            let fp_vec4 = vec![-fp_4_inv, -fp_8_inv, -fp_16_inv];

            let fp_vec5 = vec![fp_4_inv, -fp_8_inv, -fp_16_inv];
            let fp_vec6 = vec![fp_4_inv, fp_8_inv, fp_16_inv];

            // Positive entries add up.
            let fp_list = [fp_vec1, fp_vec2];
            assert_eq!(
                run_vdaf(&prio3, &(), fp_list).unwrap(),
                vec!(0.5, 0.25, 0.125),
            );

            // Negative entries add up.
            let fp_list2 = [fp_vec3, fp_vec4];
            assert_eq!(
                run_vdaf(&prio3, &(), fp_list2).unwrap(),
                vec!(-0.5, -0.25, -0.125),
            );

            // Mixed signs cancel where opposite.
            let fp_list3 = [fp_vec5, fp_vec6];
            assert_eq!(
                run_vdaf(&prio3, &(), fp_list3).unwrap(),
                vec!(0.5, 0.0, 0.0),
            );

            let mut verify_key = [0; 16];
            let mut nonce = [0; 16];
            thread_rng().fill(&mut verify_key);
            thread_rng().fill(&mut nonce);

            // Tampered joint randomness blind.
            let (public_share, mut input_shares) = prio3
                .shard(&vec![fp_4_inv, fp_8_inv, fp_16_inv], &nonce)
                .unwrap();
            input_shares[0].joint_rand_blind.as_mut().unwrap().0[0] ^= 255;
            let result =
                run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares);
            assert_matches!(result, Err(VdafError::Uncategorized(_)));

            // Tampered measurement share.
            let (public_share, mut input_shares) = prio3
                .shard(&vec![fp_4_inv, fp_8_inv, fp_16_inv], &nonce)
                .unwrap();
            assert_matches!(input_shares[0].measurement_share, Share::Leader(ref mut data) => {
                data[0] += Field128::one();
            });
            let result =
                run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares);
            assert_matches!(result, Err(VdafError::Uncategorized(_)));

            // Tampered proof share.
            let (public_share, mut input_shares) = prio3
                .shard(&vec![fp_4_inv, fp_8_inv, fp_16_inv], &nonce)
                .unwrap();
            assert_matches!(input_shares[0].proofs_share, Share::Leader(ref mut data) => {
                data[0] += Field128::one();
            });
            let result =
                run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares);
            assert_matches!(result, Err(VdafError::Uncategorized(_)));

            test_serialization(&prio3, &vec![fp_4_inv, fp_8_inv, fp_16_inv], &nonce).unwrap();
        }
    }

    #[test]
    fn test_prio3_histogram() {
        let prio3 = Prio3::new_histogram(2, 4, 2).unwrap();

        assert_eq!(
            run_vdaf(&prio3, &(), [0, 1, 2, 3]).unwrap(),
            vec![1, 1, 1, 1]
        );
        assert_eq!(run_vdaf(&prio3, &(), [0]).unwrap(), vec![1, 0, 0, 0]);
        assert_eq!(run_vdaf(&prio3, &(), [1]).unwrap(), vec![0, 1, 0, 0]);
        assert_eq!(run_vdaf(&prio3, &(), [2]).unwrap(), vec![0, 0, 1, 0]);
        assert_eq!(run_vdaf(&prio3, &(), [3]).unwrap(), vec![0, 0, 0, 1]);
        test_serialization(&prio3, &3, &[0; 16]).unwrap();
    }

    #[test]
    #[cfg(feature = "multithreaded")]
    fn test_prio3_histogram_multithreaded() {
        let prio3 = Prio3::new_histogram_multithreaded(2, 4, 2).unwrap();

        assert_eq!(
            run_vdaf(&prio3, &(), [0, 1, 2, 3]).unwrap(),
            vec![1, 1, 1, 1]
        );
        assert_eq!(run_vdaf(&prio3, &(), [0]).unwrap(), vec![1, 0, 0, 0]);
        assert_eq!(run_vdaf(&prio3, &(), [1]).unwrap(), vec![0, 1, 0, 0]);
        assert_eq!(run_vdaf(&prio3, &(), [2]).unwrap(), vec![0, 0, 1, 0]);
        assert_eq!(run_vdaf(&prio3, &(), [3]).unwrap(), vec![0, 0, 0, 1]);
        test_serialization(&prio3, &3, &[0; 16]).unwrap();
    }

    #[test]
    fn test_prio3_average() {
        let prio3 = Prio3::new_average(2, 64).unwrap();

        assert_eq!(run_vdaf(&prio3, &(), [17, 8]).unwrap(), 12.5f64);
        assert_eq!(run_vdaf(&prio3, &(), [1, 1, 1, 1]).unwrap(), 1f64);
        assert_eq!(run_vdaf(&prio3, &(), [0, 0, 0, 1]).unwrap(), 0.25f64);
        assert_eq!(
            run_vdaf(&prio3, &(), [1, 11, 111, 1111, 3, 8]).unwrap(),
            207.5f64
        );
    }

    #[test]
    fn test_prio3_input_share() {
        let prio3 = Prio3::new_sum(5, 16).unwrap();
        let (_public_share, input_shares) = prio3.shard(&1, &[0; 16]).unwrap();

        // Every pair of distinct input shares must use distinct seeds and blinds.
        for (i, x) in input_shares.iter().enumerate() {
            for (j, y) in input_shares.iter().enumerate() {
                if i != j {
                    if let (Share::Helper(left), Share::Helper(right)) =
                        (&x.measurement_share, &y.measurement_share)
                    {
                        assert_ne!(left, right);
                    }

                    if let (Share::Helper(left), Share::Helper(right)) =
                        (&x.proofs_share, &y.proofs_share)
                    {
                        assert_ne!(left, right);
                    }

                    assert_ne!(x.joint_rand_blind, y.joint_rand_blind);
                }
            }
        }
    }

    #[test]
    fn test_prepare_message_rejects_wrong_share_count() {
        let prio3 = Prio3::new_count(2).unwrap();
        let verify_key = [0; 16];
        let nonce = [0; 16];
        let (public_share, input_shares) = prio3.shard(&true, &nonce).unwrap();

        // Only the leader's prepare share is supplied, although two aggregators are configured.
        let (_, prepare_share) = prio3
            .prepare_init(&verify_key, 0, &(), &nonce, &public_share, &input_shares[0])
            .unwrap();
        let result = prio3.prepare_shares_to_prepare_message(&(), [prepare_share]);
        assert_matches!(result, Err(VdafError::Uncategorized(_)));
    }

    /// Shards `measurement` and checks that every message produced along the way round-trips
    /// through its encoding and reports a matching `encoded_len`.
    fn test_serialization<T, P, const SEED_SIZE: usize>(
        prio3: &Prio3<T, P, SEED_SIZE>,
        measurement: &T::Measurement,
        nonce: &[u8; 16],
    ) -> Result<(), VdafError>
    where
        T: Type,
        P: Xof<SEED_SIZE>,
    {
        let mut verify_key = [0; SEED_SIZE];
        thread_rng().fill(&mut verify_key[..]);
        let (public_share, input_shares) = prio3.shard(measurement, nonce)?;

        let encoded_public_share = public_share.get_encoded().unwrap();
        let decoded_public_share =
            Prio3PublicShare::get_decoded_with_param(prio3, &encoded_public_share)
                .expect("failed to decode public share");
        assert_eq!(decoded_public_share, public_share);
        assert_eq!(
            public_share.encoded_len().unwrap(),
            encoded_public_share.len()
        );

        for (agg_id, input_share) in input_shares.iter().enumerate() {
            let encoded_input_share = input_share.get_encoded().unwrap();
            let decoded_input_share =
                Prio3InputShare::get_decoded_with_param(&(prio3, agg_id), &encoded_input_share)
                    .expect("failed to decode input share");
            assert_eq!(&decoded_input_share, input_share);
            assert_eq!(
                input_share.encoded_len().unwrap(),
                encoded_input_share.len()
            );
        }

        let mut prepare_shares = Vec::new();
        let mut last_prepare_state = None;
        for (agg_id, input_share) in input_shares.iter().enumerate() {
            let (prepare_state, prepare_share) =
                prio3.prepare_init(&verify_key, agg_id, &(), nonce, &public_share, input_share)?;

            let encoded_prepare_state = prepare_state.get_encoded().unwrap();
            let decoded_prepare_state =
                Prio3PrepareState::get_decoded_with_param(&(prio3, agg_id), &encoded_prepare_state)
                    .expect("failed to decode prepare state");
            assert_eq!(decoded_prepare_state, prepare_state);
            assert_eq!(
                prepare_state.encoded_len().unwrap(),
                encoded_prepare_state.len()
            );

            let encoded_prepare_share = prepare_share.get_encoded().unwrap();
            let decoded_prepare_share =
                Prio3PrepareShare::get_decoded_with_param(&prepare_state, &encoded_prepare_share)
                    .expect("failed to decode prepare share");
            assert_eq!(decoded_prepare_share, prepare_share);
            assert_eq!(
                prepare_share.encoded_len().unwrap(),
                encoded_prepare_share.len()
            );

            prepare_shares.push(prepare_share);
            last_prepare_state = Some(prepare_state);
        }

        let prepare_message = prio3
            .prepare_shares_to_prepare_message(&(), prepare_shares)
            .unwrap();

        let encoded_prepare_message = prepare_message.get_encoded().unwrap();
        let decoded_prepare_message = Prio3PrepareMessage::get_decoded_with_param(
            &last_prepare_state.unwrap(),
            &encoded_prepare_message,
        )
        .expect("failed to decode prepare message");
        assert_eq!(decoded_prepare_message, prepare_message);
        assert_eq!(
            prepare_message.encoded_len().unwrap(),
            encoded_prepare_message.len()
        );

        Ok(())
    }

    #[test]
    fn roundtrip_output_share() {
        let vdaf = Prio3::new_count(2).unwrap();
        fieldvec_roundtrip_test::<Field64, Prio3Count, OutputShare<Field64>>(&vdaf, &(), 1);

        let vdaf = Prio3::new_sum(2, 17).unwrap();
        fieldvec_roundtrip_test::<Field128, Prio3Sum, OutputShare<Field128>>(&vdaf, &(), 1);

        let vdaf = Prio3::new_histogram(2, 12, 3).unwrap();
        fieldvec_roundtrip_test::<Field128, Prio3Histogram, OutputShare<Field128>>(&vdaf, &(), 12);
    }

    #[test]
    fn roundtrip_aggregate_share() {
        let vdaf = Prio3::new_count(2).unwrap();
        fieldvec_roundtrip_test::<Field64, Prio3Count, AggregateShare<Field64>>(&vdaf, &(), 1);

        let vdaf = Prio3::new_sum(2, 17).unwrap();
        fieldvec_roundtrip_test::<Field128, Prio3Sum, AggregateShare<Field128>>(&vdaf, &(), 1);

        let vdaf = Prio3::new_histogram(2, 12, 3).unwrap();
        fieldvec_roundtrip_test::<Field128, Prio3Histogram, AggregateShare<Field128>>(
            &vdaf,
            &(),
            12,
        );
    }

    #[test]
    fn public_share_equality_test() {
        equality_comparison_test(&[
            Prio3PublicShare {
                joint_rand_parts: Some(Vec::from([Seed([0])])),
            },
            Prio3PublicShare {
                joint_rand_parts: Some(Vec::from([Seed([1])])),
            },
            Prio3PublicShare {
                joint_rand_parts: None,
            },
        ])
    }

    #[test]
    fn input_share_equality_test() {
        equality_comparison_test(&[
            // Base.
            Prio3InputShare {
                measurement_share: Share::Leader(Vec::from([0])),
                proofs_share: Share::Leader(Vec::from([1])),
                joint_rand_blind: Some(Seed([2])),
            },
            // Differs in measurement share.
            Prio3InputShare {
                measurement_share: Share::Leader(Vec::from([100])),
                proofs_share: Share::Leader(Vec::from([1])),
                joint_rand_blind: Some(Seed([2])),
            },
            // Differs in proof share.
            Prio3InputShare {
                measurement_share: Share::Leader(Vec::from([0])),
                proofs_share: Share::Leader(Vec::from([101])),
                joint_rand_blind: Some(Seed([2])),
            },
            // Differs in joint randomness blind.
            Prio3InputShare {
                measurement_share: Share::Leader(Vec::from([0])),
                proofs_share: Share::Leader(Vec::from([1])),
                joint_rand_blind: Some(Seed([102])),
            },
            // No joint randomness blind.
            Prio3InputShare {
                measurement_share: Share::Leader(Vec::from([0])),
                proofs_share: Share::Leader(Vec::from([1])),
                joint_rand_blind: None,
            },
        ])
    }

    #[test]
    fn prepare_share_equality_test() {
        equality_comparison_test(&[
            // Base.
            Prio3PrepareShare {
                verifiers: Vec::from([0]),
                joint_rand_part: Some(Seed([1])),
            },
            // Differs in verifiers.
            Prio3PrepareShare {
                verifiers: Vec::from([100]),
                joint_rand_part: Some(Seed([1])),
            },
            // Differs in joint randomness part.
            Prio3PrepareShare {
                verifiers: Vec::from([0]),
                joint_rand_part: Some(Seed([101])),
            },
            // No joint randomness part.
            Prio3PrepareShare {
                verifiers: Vec::from([0]),
                joint_rand_part: None,
            },
        ])
    }

    #[test]
    fn prepare_message_equality_test() {
        equality_comparison_test(&[
            // Base.
            Prio3PrepareMessage {
                joint_rand_seed: Some(Seed([0])),
            },
            // Differs in joint randomness seed.
            Prio3PrepareMessage {
                joint_rand_seed: Some(Seed([100])),
            },
            // No joint randomness seed.
            Prio3PrepareMessage {
                joint_rand_seed: None,
            },
        ])
    }

    #[test]
    fn prepare_state_equality_test() {
        equality_comparison_test(&[
            // Base.
            Prio3PrepareState {
                measurement_share: Share::Leader(Vec::from([0])),
                joint_rand_seed: Some(Seed([1])),
                agg_id: 2,
                verifiers_len: 3,
            },
            // Differs in measurement share.
            Prio3PrepareState {
                measurement_share: Share::Leader(Vec::from([100])),
                joint_rand_seed: Some(Seed([1])),
                agg_id: 2,
                verifiers_len: 3,
            },
            // Differs in joint randomness seed.
            Prio3PrepareState {
                measurement_share: Share::Leader(Vec::from([0])),
                joint_rand_seed: Some(Seed([101])),
                agg_id: 2,
                verifiers_len: 3,
            },
            // No joint randomness seed.
            Prio3PrepareState {
                measurement_share: Share::Leader(Vec::from([0])),
                joint_rand_seed: None,
                agg_id: 2,
                verifiers_len: 3,
            },
            // Differs in aggregator ID.
            Prio3PrepareState {
                measurement_share: Share::Leader(Vec::from([0])),
                joint_rand_seed: Some(Seed([1])),
                agg_id: 102,
                verifiers_len: 3,
            },
            // Differs in verifier length.
            Prio3PrepareState {
                measurement_share: Share::Leader(Vec::from([0])),
                joint_rand_seed: Some(Seed([1])),
                agg_id: 2,
                verifiers_len: 103,
            },
        ])
    }

    #[test]
    fn test_optimal_chunk_length() {
        // Degenerate lengths fall back to a chunk length of 1.
        assert_eq!(optimal_chunk_length(0), 1);
        assert_eq!(optimal_chunk_length(1), 1);

        assert_eq!(optimal_chunk_length(2), 2);
        assert_eq!(optimal_chunk_length(3), 1);
        assert_eq!(optimal_chunk_length(18), 6);
        assert_eq!(optimal_chunk_length(19), 3);

        assert_eq!(optimal_chunk_length(40), 6);
        assert_eq!(optimal_chunk_length(10_000), 79);
        assert_eq!(optimal_chunk_length(100_000), 393);

        // Brute force: no chunk length gives a shorter Histogram proof than the chosen one.
        for measurement_length in [2, 3, 4, 5, 18, 19, 40] {
            let best_chunk_length = optimal_chunk_length(measurement_length);
            let optimal_proof_length = Histogram::<Field128, ParallelSum<_, _>>::new(
                measurement_length,
                best_chunk_length,
            )
            .unwrap()
            .proof_len();
            for chunk_length in 1..=measurement_length {
                let proof_length =
                    Histogram::<Field128, ParallelSum<_, _>>::new(measurement_length, chunk_length)
                        .unwrap()
                        .proof_len();
                assert!(proof_length >= optimal_proof_length);
            }
        }
    }
}