prio/vdaf/
prio3.rs

1// SPDX-License-Identifier: MPL-2.0
2
3//! Implementation of the Prio3 VDAF [[draft-irtf-cfrg-vdaf-08]].
4//!
5//! **WARNING:** This code has not undergone significant security analysis. Use at your own risk.
6//!
//! Prio3 is based on the Prio system designed by Dan Boneh and Henry Corrigan-Gibbs and presented
8//! at NSDI 2017 [[CGB17]]. However, it incorporates a few techniques from Boneh et al., CRYPTO
9//! 2019 [[BBCG+19]], that lead to substantial improvements in terms of run time and communication
10//! cost. The security of the construction was analyzed in [[DPRS23]].
11//!
12//! Prio3 is a transformation of a Fully Linear Proof (FLP) system [[draft-irtf-cfrg-vdaf-08]] into
13//! a VDAF. The base type, [`Prio3`], supports a wide variety of aggregation functions, some of
14//! which are instantiated here:
15//!
16//! - [`Prio3Count`] for aggregating a counter (*)
//! - [`Prio3Sum`] for computing the sum of integers (*)
18//! - [`Prio3SumVec`] for aggregating a vector of integers
19//! - [`Prio3Histogram`] for estimating a distribution via a histogram (*)
20//!
21//! Additional types can be constructed from [`Prio3`] as needed.
22//!
23//! (*) denotes that the type is specified in [[draft-irtf-cfrg-vdaf-08]].
24//!
25//! [BBCG+19]: https://ia.cr/2019/188
26//! [CGB17]: https://crypto.stanford.edu/prio/
27//! [DPRS23]: https://ia.cr/2023/130
28//! [draft-irtf-cfrg-vdaf-08]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/08/
29
30use super::xof::XofTurboShake128;
31#[cfg(feature = "experimental")]
32use super::AggregatorWithNoise;
33use crate::codec::{CodecError, Decode, Encode, ParameterizedDecode};
34#[cfg(feature = "experimental")]
35use crate::dp::DifferentialPrivacyStrategy;
36use crate::field::{decode_fieldvec, FftFriendlyFieldElement, FieldElement};
37use crate::field::{Field128, Field64};
38#[cfg(feature = "multithreaded")]
39use crate::flp::gadgets::ParallelSumMultithreaded;
40#[cfg(feature = "experimental")]
41use crate::flp::gadgets::PolyEval;
42use crate::flp::gadgets::{Mul, ParallelSum};
43#[cfg(feature = "experimental")]
44use crate::flp::types::fixedpoint_l2::{
45    compatible_float::CompatibleFloat, FixedPointBoundedL2VecSum,
46};
47use crate::flp::types::{Average, Count, Histogram, Sum, SumVec};
48use crate::flp::Type;
49#[cfg(feature = "experimental")]
50use crate::flp::TypeWithNoise;
51use crate::prng::Prng;
52use crate::vdaf::xof::{IntoFieldVec, Seed, Xof};
53use crate::vdaf::{
54    Aggregatable, AggregateShare, Aggregator, Client, Collector, OutputShare, PrepareTransition,
55    Share, ShareDecodingParameter, Vdaf, VdafError,
56};
57#[cfg(feature = "experimental")]
58use fixed::traits::Fixed;
59use std::borrow::Cow;
60use std::convert::TryFrom;
61use std::fmt::Debug;
62use std::io::Cursor;
63use std::iter::{self, IntoIterator};
64use std::marker::PhantomData;
65use subtle::{Choice, ConstantTimeEq};
66
// Domain-separation "usage" values. Each one is passed to `domain_separation_tag` so that every
// XOF invocation made by Prio3 is bound to a distinct purpose. NOTE(review): these are expected to
// match the usage constants of draft-irtf-cfrg-vdaf-08 — verify when moving to a new draft.

/// Usage for expanding a helper's measurement-share seed into its measurement share.
const DST_MEASUREMENT_SHARE: u16 = 1;
/// Usage for expanding a helper's proof-share seed into its proofs share.
const DST_PROOF_SHARE: u16 = 2;
/// Usage for expanding the joint randomness seed into joint randomness field elements.
const DST_JOINT_RANDOMNESS: u16 = 3;
/// Usage for expanding the prover randomness seed into prover randomness field elements.
const DST_PROVE_RANDOMNESS: u16 = 4;
/// Usage for deriving query randomness from the verification key and the report nonce.
const DST_QUERY_RANDOMNESS: u16 = 5;
/// Usage for combining all joint randomness parts into the joint randomness seed.
const DST_JOINT_RAND_SEED: u16 = 6;
/// Usage for deriving one aggregator's joint randomness part from its blind and measurement share.
const DST_JOINT_RAND_PART: u16 = 7;
74
75/// The count type. Each measurement is an integer in `[0,2)` and the aggregate result is the sum.
76pub type Prio3Count = Prio3<Count<Field64>, XofTurboShake128, 16>;
77
78impl Prio3Count {
79    /// Construct an instance of Prio3Count with the given number of aggregators.
80    pub fn new_count(num_aggregators: u8) -> Result<Self, VdafError> {
81        Prio3::new(num_aggregators, 1, 0x00000000, Count::new())
82    }
83}
84
85/// The count-vector type. Each measurement is a vector of integers in `[0,2^bits)` and the
86/// aggregate is the element-wise sum.
87pub type Prio3SumVec =
88    Prio3<SumVec<Field128, ParallelSum<Field128, Mul<Field128>>>, XofTurboShake128, 16>;
89
90impl Prio3SumVec {
91    /// Construct an instance of Prio3SumVec with the given number of aggregators. `bits` defines
92    /// the bit width of each summand of the measurement; `len` defines the length of the
93    /// measurement vector.
94    pub fn new_sum_vec(
95        num_aggregators: u8,
96        bits: usize,
97        len: usize,
98        chunk_length: usize,
99    ) -> Result<Self, VdafError> {
100        Prio3::new(
101            num_aggregators,
102            1,
103            0x00000002,
104            SumVec::new(bits, len, chunk_length)?,
105        )
106    }
107}
108
/// Multithreaded variant of [`Prio3SumVec`]: the parallel-sum gadget is
/// [`ParallelSumMultithreaded`], which can speed up sharding and preparation. The speedup only
/// becomes noticeable for very large input lengths.
#[cfg(feature = "multithreaded")]
#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))]
pub type Prio3SumVecMultithreaded = Prio3<
    SumVec<Field128, ParallelSumMultithreaded<Field128, Mul<Field128>>>,
    XofTurboShake128,
    16,
>;

#[cfg(feature = "multithreaded")]
impl Prio3SumVecMultithreaded {
    /// Build a Prio3SumVecMultithreaded for `num_aggregators` aggregators.
    ///
    /// Each of the `len` entries of a measurement is a `bits`-bit summand; `chunk_length` sets the
    /// chunk length of the parallel-sum gadget. Errors from `SumVec::new` and [`Prio3::new`] are
    /// propagated unchanged.
    pub fn new_sum_vec_multithreaded(
        num_aggregators: u8,
        bits: usize,
        len: usize,
        chunk_length: usize,
    ) -> Result<Self, VdafError> {
        let sum_vec = SumVec::new(bits, len, chunk_length)?;
        Prio3::new(num_aggregators, 1, 0x00000002, sum_vec)
    }
}
138
139/// The sum type. Each measurement is an integer in `[0,2^bits)` for some `0 < bits < 64` and the
140/// aggregate is the sum.
141pub type Prio3Sum = Prio3<Sum<Field128>, XofTurboShake128, 16>;
142
143impl Prio3Sum {
144    /// Construct an instance of Prio3Sum with the given number of aggregators and required bit
145    /// length. The bit length must not exceed 64.
146    pub fn new_sum(num_aggregators: u8, bits: usize) -> Result<Self, VdafError> {
147        if bits > 64 {
148            return Err(VdafError::Uncategorized(format!(
149                "bit length ({bits}) exceeds limit for aggregate type (64)"
150            )));
151        }
152
153        Prio3::new(num_aggregators, 1, 0x00000001, Sum::new(bits)?)
154    }
155}
156
/// The fixed point vector sum type. Each measurement is a vector of fixed point numbers
/// and the aggregate is the sum represented as 64-bit floats. The preparation phase
/// ensures the L2 norm of the input vector is < 1.
///
/// This is useful for aggregating gradients in a federated version of
/// [gradient descent](https://en.wikipedia.org/wiki/Gradient_descent) with
/// [differential privacy](https://en.wikipedia.org/wiki/Differential_privacy),
/// useful, e.g., for [differentially private deep learning](https://arxiv.org/pdf/1607.00133.pdf).
/// The bound on input norms is required for differential privacy. The fixed point representation
/// allows an easy conversion to the integer type used in internal computation, while leaving
/// conversion to the client. The model itself will have floating point parameters, so the output
/// sum has that type as well.
#[cfg(feature = "experimental")]
#[cfg_attr(docsrs, doc(cfg(feature = "experimental")))]
pub type Prio3FixedPointBoundedL2VecSum<Fx> = Prio3<
    FixedPointBoundedL2VecSum<
        Fx,
        ParallelSum<Field128, PolyEval<Field128>>,
        ParallelSum<Field128, Mul<Field128>>,
    >,
    XofTurboShake128,
    16,
>;

#[cfg(feature = "experimental")]
impl<Fx: Fixed + CompatibleFloat> Prio3FixedPointBoundedL2VecSum<Fx> {
    /// Build an instance of this VDAF for `num_aggregators` aggregators over vectors of `entries`
    /// fixed point values.
    ///
    /// Uses a single proof per report and algorithm ID `0xFFFF0000`. An invalid aggregator count
    /// is reported before any error from `FixedPointBoundedL2VecSum::new`.
    pub fn new_fixedpoint_boundedl2_vec_sum(
        num_aggregators: u8,
        entries: usize,
    ) -> Result<Self, VdafError> {
        // Checked up front so this error takes precedence over FLP-type construction errors.
        check_num_aggregators(num_aggregators)?;
        let vec_sum = FixedPointBoundedL2VecSum::new(entries)?;
        Prio3::new(num_aggregators, 1, 0xFFFF0000, vec_sum)
    }
}
198
/// Multithreaded variant of the fixed point vector sum type: both parallel-sum gadgets are
/// [`ParallelSumMultithreaded`]. Each measurement is a vector of fixed point numbers and the
/// aggregate is the sum represented as 64-bit floats. The verification function ensures the L2
/// norm of the input vector is < 1.
#[cfg(all(feature = "experimental", feature = "multithreaded"))]
#[cfg_attr(
    docsrs,
    doc(cfg(all(feature = "experimental", feature = "multithreaded")))
)]
pub type Prio3FixedPointBoundedL2VecSumMultithreaded<Fx> = Prio3<
    FixedPointBoundedL2VecSum<
        Fx,
        ParallelSumMultithreaded<Field128, PolyEval<Field128>>,
        ParallelSumMultithreaded<Field128, Mul<Field128>>,
    >,
    XofTurboShake128,
    16,
>;

#[cfg(all(feature = "experimental", feature = "multithreaded"))]
impl<Fx: Fixed + CompatibleFloat> Prio3FixedPointBoundedL2VecSumMultithreaded<Fx> {
    /// Build an instance of this VDAF for `num_aggregators` aggregators over vectors of `entries`
    /// fixed point values.
    ///
    /// Uses a single proof per report and algorithm ID `0xFFFF0000`. An invalid aggregator count
    /// is reported before any error from `FixedPointBoundedL2VecSum::new`.
    pub fn new_fixedpoint_boundedl2_vec_sum_multithreaded(
        num_aggregators: u8,
        entries: usize,
    ) -> Result<Self, VdafError> {
        // Checked up front so this error takes precedence over FLP-type construction errors.
        check_num_aggregators(num_aggregators)?;
        let vec_sum = FixedPointBoundedL2VecSum::new(entries)?;
        Prio3::new(num_aggregators, 1, 0xFFFF0000, vec_sum)
    }
}
234
235/// The histogram type. Each measurement is an integer in `[0, length)` and the result is a
236/// histogram counting the number of occurrences of each measurement.
237pub type Prio3Histogram =
238    Prio3<Histogram<Field128, ParallelSum<Field128, Mul<Field128>>>, XofTurboShake128, 16>;
239
240impl Prio3Histogram {
241    /// Constructs an instance of Prio3Histogram with the given number of aggregators,
242    /// number of buckets, and parallel sum gadget chunk length.
243    pub fn new_histogram(
244        num_aggregators: u8,
245        length: usize,
246        chunk_length: usize,
247    ) -> Result<Self, VdafError> {
248        Prio3::new(
249            num_aggregators,
250            1,
251            0x00000003,
252            Histogram::new(length, chunk_length)?,
253        )
254    }
255}
256
/// Multithreaded variant of [`Prio3Histogram`]: the parallel-sum gadget is
/// [`ParallelSumMultithreaded`], which can speed up sharding and preparation. The speedup only
/// becomes noticeable for very large input lengths.
#[cfg(feature = "multithreaded")]
#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))]
pub type Prio3HistogramMultithreaded = Prio3<
    Histogram<Field128, ParallelSumMultithreaded<Field128, Mul<Field128>>>,
    XofTurboShake128,
    16,
>;

#[cfg(feature = "multithreaded")]
impl Prio3HistogramMultithreaded {
    /// Build a Prio3HistogramMultithreaded for `num_aggregators` aggregators with `length`
    /// buckets, using `chunk_length` as the chunk length of the parallel-sum gadget.
    ///
    /// Errors from `Histogram::new` and [`Prio3::new`] are propagated unchanged.
    pub fn new_histogram_multithreaded(
        num_aggregators: u8,
        length: usize,
        chunk_length: usize,
    ) -> Result<Self, VdafError> {
        let histogram = Histogram::new(length, chunk_length)?;
        Prio3::new(num_aggregators, 1, 0x00000003, histogram)
    }
}
284
285/// The average type. Each measurement is an integer in `[0,2^bits)` for some `0 < bits < 64` and
286/// the aggregate is the arithmetic average.
287pub type Prio3Average = Prio3<Average<Field128>, XofTurboShake128, 16>;
288
289impl Prio3Average {
290    /// Construct an instance of Prio3Average with the given number of aggregators and required bit
291    /// length. The bit length must not exceed 64.
292    pub fn new_average(num_aggregators: u8, bits: usize) -> Result<Self, VdafError> {
293        check_num_aggregators(num_aggregators)?;
294
295        if bits > 64 {
296            return Err(VdafError::Uncategorized(format!(
297                "bit length ({bits}) exceeds limit for aggregate type (64)"
298            )));
299        }
300
301        Ok(Prio3 {
302            num_aggregators,
303            num_proofs: 1,
304            algorithm_id: 0xFFFF0000,
305            typ: Average::new(bits)?,
306            phantom: PhantomData,
307        })
308    }
309}
310
/// The base type for Prio3.
///
/// An instance of Prio3 is determined by:
///
/// - a [`Type`] that defines the set of valid input measurements;
/// - a [`Xof`] for deriving vectors of field elements from seeds; and
/// - the number of aggregators, the number of proofs to generate and verify per report, and the
///   algorithm ID, all supplied to [`Prio3::new`].
///
/// New instances can be defined by aliasing the base type. For example, [`Prio3Count`] is an alias
/// for `Prio3<Count<Field64>, XofTurboShake128, 16>`.
///
/// ```
/// use prio::vdaf::{
///     Aggregator, Client, Collector, PrepareTransition,
///     prio3::Prio3,
/// };
/// use rand::prelude::*;
///
/// let num_shares = 2;
/// let vdaf = Prio3::new_count(num_shares).unwrap();
///
/// let mut out_shares = vec![vec![]; num_shares.into()];
/// let mut rng = thread_rng();
/// let verify_key = rng.gen();
/// let measurements = [false, true, true, true, false];
/// for measurement in measurements {
///     // Shard
///     let nonce = rng.gen::<[u8; 16]>();
///     let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap();
///
///     // Prepare
///     let mut prep_states = vec![];
///     let mut prep_shares = vec![];
///     for (agg_id, input_share) in input_shares.iter().enumerate() {
///         let (state, share) = vdaf.prepare_init(
///             &verify_key,
///             agg_id,
///             &(),
///             &nonce,
///             &public_share,
///             input_share
///         ).unwrap();
///         prep_states.push(state);
///         prep_shares.push(share);
///     }
///     let prep_msg = vdaf.prepare_shares_to_prepare_message(&(), prep_shares).unwrap();
///
///     for (agg_id, state) in prep_states.into_iter().enumerate() {
///         let out_share = match vdaf.prepare_next(state, prep_msg.clone()).unwrap() {
///             PrepareTransition::Finish(out_share) => out_share,
///             _ => panic!("unexpected transition"),
///         };
///         out_shares[agg_id].push(out_share);
///     }
/// }
///
/// // Aggregate
/// let agg_shares = out_shares.into_iter()
///     .map(|o| vdaf.aggregate(&(), o).unwrap());
///
/// // Unshard
/// let agg_res = vdaf.unshard(&(), agg_shares, measurements.len()).unwrap();
/// assert_eq!(agg_res, 3);
/// ```
#[derive(Clone, Debug)]
pub struct Prio3<T, P, const SEED_SIZE: usize>
where
    T: Type,
    P: Xof<SEED_SIZE>,
{
    /// Number of aggregators; sharding produces this many input shares per report.
    num_aggregators: u8,
    /// Number of FLP proofs generated and verified per report (at least 1, enforced by `new`).
    num_proofs: u8,
    /// Algorithm ID reported through `Vdaf::algorithm_id`.
    algorithm_id: u32,
    /// The FLP type used to encode measurements and generate proofs.
    typ: T,
    /// Marks the XOF type `P`, which is only used through its associated functions.
    phantom: PhantomData<P>,
}
386
impl<T, P, const SEED_SIZE: usize> Prio3<T, P, SEED_SIZE>
where
    T: Type,
    P: Xof<SEED_SIZE>,
{
    /// Construct an instance of this Prio3 VDAF with the given number of aggregators, number of
    /// proofs to generate and verify, the algorithm ID, and the underlying type.
    ///
    /// # Errors
    ///
    /// Returns an error if `check_num_aggregators` rejects `num_aggregators`, or if `num_proofs`
    /// is zero.
    pub fn new(
        num_aggregators: u8,
        num_proofs: u8,
        algorithm_id: u32,
        typ: T,
    ) -> Result<Self, VdafError> {
        check_num_aggregators(num_aggregators)?;
        if num_proofs == 0 {
            return Err(VdafError::Uncategorized(
                "num_proofs must be at least 1".to_string(),
            ));
        }

        Ok(Self {
            num_aggregators,
            num_proofs,
            algorithm_id,
            typ,
            phantom: PhantomData,
        })
    }

    /// The output length of the underlying FLP.
    pub fn output_len(&self) -> usize {
        self.typ.output_len()
    }

    /// The verifier length of the underlying FLP.
    pub fn verifier_len(&self) -> usize {
        self.typ.verifier_len()
    }

    /// The number of proofs as a `usize`, for length and index arithmetic.
    #[inline]
    fn num_proofs(&self) -> usize {
        self.num_proofs.into()
    }

    /// Expand the prover randomness seed into `prove_rand_len() * num_proofs` field elements
    /// (one `prove_rand_len()`-sized slice per proof). The number of proofs is passed to the XOF
    /// as additional input.
    fn derive_prove_rands(&self, prove_rand_seed: &Seed<SEED_SIZE>) -> Vec<T::Field> {
        P::seed_stream(
            prove_rand_seed,
            &self.domain_separation_tag(DST_PROVE_RANDOMNESS),
            &[self.num_proofs],
        )
        .into_field_vec(self.typ.prove_rand_len() * self.num_proofs())
    }

    /// Combine the joint randomness parts into the joint randomness seed by feeding each part, in
    /// iteration order, into an XOF keyed with an all-zero seed. `shard_with_random` supplies the
    /// leader's part first, followed by the helpers' parts.
    fn derive_joint_rand_seed<'a>(
        &self,
        joint_rand_parts: impl Iterator<Item = &'a Seed<SEED_SIZE>>,
    ) -> Seed<SEED_SIZE> {
        let mut xof = P::init(
            &[0; SEED_SIZE],
            &self.domain_separation_tag(DST_JOINT_RAND_SEED),
        );
        for part in joint_rand_parts {
            xof.update(part.as_ref());
        }
        xof.into_seed()
    }

    /// Derive the joint randomness seed from the parts and expand it into
    /// `joint_rand_len() * num_proofs` field elements. Returns `(seed, joint_rands)`.
    fn derive_joint_rands<'a>(
        &self,
        joint_rand_parts: impl Iterator<Item = &'a Seed<SEED_SIZE>>,
    ) -> (Seed<SEED_SIZE>, Vec<T::Field>) {
        let joint_rand_seed = self.derive_joint_rand_seed(joint_rand_parts);
        let joint_rands = P::seed_stream(
            &joint_rand_seed,
            &self.domain_separation_tag(DST_JOINT_RANDOMNESS),
            &[self.num_proofs],
        )
        .into_field_vec(self.typ.joint_rand_len() * self.num_proofs());

        (joint_rand_seed, joint_rands)
    }

    /// Return a PRNG that expands a helper's proof-share seed into its proofs share. Both the
    /// number of proofs and the helper's aggregator ID are passed to the XOF as additional input.
    fn derive_helper_proofs_share(
        &self,
        proofs_share_seed: &Seed<SEED_SIZE>,
        agg_id: u8,
    ) -> Prng<T::Field, P::SeedStream> {
        Prng::from_seed_stream(P::seed_stream(
            proofs_share_seed,
            &self.domain_separation_tag(DST_PROOF_SHARE),
            &[self.num_proofs, agg_id],
        ))
    }

    /// Derive `query_rand_len() * num_proofs` field elements of query randomness from the
    /// verification key, the number of proofs, and the report nonce.
    fn derive_query_rands(&self, verify_key: &[u8; SEED_SIZE], nonce: &[u8; 16]) -> Vec<T::Field> {
        let mut xof = P::init(
            verify_key,
            &self.domain_separation_tag(DST_QUERY_RANDOMNESS),
        );
        xof.update(&[self.num_proofs]);
        xof.update(nonce);
        xof.into_seed_stream()
            .into_field_vec(self.typ.query_rand_len() * self.num_proofs())
    }

    /// The exact number of random bytes that `shard_with_random` requires for one report.
    fn random_size(&self) -> usize {
        if self.typ.joint_rand_len() == 0 {
            // Two seeds per helper for measurement and proof shares, plus one seed for proving
            // randomness.
            (usize::from(self.num_aggregators - 1) * 2 + 1) * SEED_SIZE
        } else {
            (
                // Two seeds per helper for measurement and proof shares
                usize::from(self.num_aggregators - 1) * 2
                // One seed for proving randomness
                + 1
                // One seed per aggregator for joint randomness blinds
                + usize::from(self.num_aggregators)
            ) * SEED_SIZE
        }
    }

    /// Shard `measurement` into a public share and one input share per aggregator, using
    /// `random` as the only source of randomness.
    ///
    /// `random` must be exactly `random_size()` bytes long. It is consumed as consecutive
    /// `SEED_SIZE`-byte seeds in this order:
    ///
    /// 1. for each helper `1..num_aggregators`: its measurement-share seed, its proof-share seed,
    ///    and (only if the type uses joint randomness) its joint randomness blind;
    /// 2. the leader's joint randomness blind (only if the type uses joint randomness);
    /// 3. the prover randomness seed.
    ///
    /// The leader (aggregator 0) receives its measurement and proofs shares in expanded form;
    /// each helper receives only the seeds its shares are expanded from.
    ///
    /// # Errors
    ///
    /// Returns an error if `random` has the wrong length, or if encoding the measurement,
    /// encoding a measurement share, or generating a proof fails.
    #[allow(clippy::type_complexity)]
    pub(crate) fn shard_with_random<const N: usize>(
        &self,
        measurement: &T::Measurement,
        nonce: &[u8; N],
        random: &[u8],
    ) -> Result<
        (
            Prio3PublicShare<SEED_SIZE>,
            Vec<Prio3InputShare<T::Field, SEED_SIZE>>,
        ),
        VdafError,
    > {
        if random.len() != self.random_size() {
            return Err(VdafError::Uncategorized(
                "incorrect random input length".to_string(),
            ));
        }
        let mut random_seeds = random.chunks_exact(SEED_SIZE);
        let num_aggregators = self.num_aggregators;
        let encoded_measurement = self.typ.encode_measurement(measurement)?;

        // Generate the measurement shares and compute the joint randomness.
        let mut helper_shares = Vec::with_capacity(num_aggregators as usize - 1);
        let mut helper_joint_rand_parts = if self.typ.joint_rand_len() > 0 {
            Some(Vec::with_capacity(num_aggregators as usize - 1))
        } else {
            None
        };
        // The leader's share starts as the full encoded measurement; each helper's expanded share
        // is subtracted from it below, so that all shares sum to the encoded measurement.
        let mut leader_measurement_share = encoded_measurement.clone();
        for agg_id in 1..num_aggregators {
            // The Option from the ChunksExact iterator is okay to unwrap because we checked that
            // the randomness slice is long enough for this VDAF. The slice-to-array conversion
            // Result is okay to unwrap because the ChunksExact iterator always returns slices of
            // the correct length.
            let measurement_share_seed = random_seeds.next().unwrap().try_into().unwrap();
            let proof_share_seed = random_seeds.next().unwrap().try_into().unwrap();
            let measurement_share_prng: Prng<T::Field, _> = Prng::from_seed_stream(P::seed_stream(
                &Seed(measurement_share_seed),
                &self.domain_separation_tag(DST_MEASUREMENT_SHARE),
                &[agg_id],
            ));
            let joint_rand_blind = if let Some(helper_joint_rand_parts) =
                helper_joint_rand_parts.as_mut()
            {
                // This helper's joint randomness part is derived from its blind, its aggregator
                // ID, the nonce, and the encoding of its expanded measurement share.
                let joint_rand_blind = random_seeds.next().unwrap().try_into().unwrap();
                let mut joint_rand_part_xof = P::init(
                    &joint_rand_blind,
                    &self.domain_separation_tag(DST_JOINT_RAND_PART),
                );
                joint_rand_part_xof.update(&[agg_id]); // Aggregator ID
                joint_rand_part_xof.update(nonce);

                let mut encoding_buffer = Vec::with_capacity(T::Field::ENCODED_SIZE);
                for (x, y) in leader_measurement_share
                    .iter_mut()
                    .zip(measurement_share_prng)
                {
                    *x -= y;
                    y.encode(&mut encoding_buffer).map_err(|_| {
                        VdafError::Uncategorized("failed to encode measurement share".to_string())
                    })?;
                    joint_rand_part_xof.update(&encoding_buffer);
                    encoding_buffer.clear();
                }

                helper_joint_rand_parts.push(joint_rand_part_xof.into_seed());

                Some(joint_rand_blind)
            } else {
                for (x, y) in leader_measurement_share
                    .iter_mut()
                    .zip(measurement_share_prng)
                {
                    *x -= y;
                }
                None
            };
            let helper =
                HelperShare::from_seeds(measurement_share_seed, proof_share_seed, joint_rand_blind);
            helper_shares.push(helper);
        }

        // Only when joint randomness is used: derive the leader's part from its final
        // measurement share and assemble all parts, leader first, into the public share.
        let mut leader_blind_opt = None;
        let public_share = Prio3PublicShare {
            joint_rand_parts: helper_joint_rand_parts
                .as_ref()
                .map(
                    |helper_joint_rand_parts| -> Result<Vec<Seed<SEED_SIZE>>, VdafError> {
                        let leader_blind_bytes = random_seeds.next().unwrap().try_into().unwrap();
                        let leader_blind = Seed::from_bytes(leader_blind_bytes);

                        let mut joint_rand_part_xof = P::init(
                            leader_blind.as_ref(),
                            &self.domain_separation_tag(DST_JOINT_RAND_PART),
                        );
                        joint_rand_part_xof.update(&[0]); // Aggregator ID
                        joint_rand_part_xof.update(nonce);
                        let mut encoding_buffer = Vec::with_capacity(T::Field::ENCODED_SIZE);
                        for x in leader_measurement_share.iter() {
                            x.encode(&mut encoding_buffer).map_err(|_| {
                                VdafError::Uncategorized(
                                    "failed to encode measurement share".to_string(),
                                )
                            })?;
                            joint_rand_part_xof.update(&encoding_buffer);
                            encoding_buffer.clear();
                        }
                        leader_blind_opt = Some(leader_blind);

                        let leader_joint_rand_seed_part = joint_rand_part_xof.into_seed();

                        let mut vec = Vec::with_capacity(self.num_aggregators());
                        vec.push(leader_joint_rand_seed_part);
                        vec.extend(helper_joint_rand_parts.iter().cloned());
                        Ok(vec)
                    },
                )
                .transpose()?,
        };

        // Compute the joint randomness. It is empty when the type uses no joint randomness, in
        // which case `joint_rand_len()` is 0 and the per-proof slices below are empty.
        let joint_rands = public_share
            .joint_rand_parts
            .as_ref()
            .map(|joint_rand_parts| self.derive_joint_rands(joint_rand_parts.iter()).1)
            .unwrap_or_default();

        // Generate the proofs. The leader's proofs share starts as all proofs concatenated.
        let prove_rands = self.derive_prove_rands(&Seed::from_bytes(
            random_seeds.next().unwrap().try_into().unwrap(),
        ));
        let mut leader_proofs_share = Vec::with_capacity(self.typ.proof_len() * self.num_proofs());
        for p in 0..self.num_proofs() {
            let prove_rand =
                &prove_rands[p * self.typ.prove_rand_len()..(p + 1) * self.typ.prove_rand_len()];
            let joint_rand =
                &joint_rands[p * self.typ.joint_rand_len()..(p + 1) * self.typ.joint_rand_len()];

            leader_proofs_share.append(&mut self.typ.prove(
                &encoded_measurement,
                prove_rand,
                joint_rand,
            )?);
        }

        // Generate the proof shares: subtract each helper's expanded proofs share (helper j has
        // aggregator ID j + 1) from the leader's.
        for (j, helper) in helper_shares.iter_mut().enumerate() {
            for (x, y) in
                leader_proofs_share
                    .iter_mut()
                    .zip(self.derive_helper_proofs_share(
                        &helper.proofs_share,
                        u8::try_from(j).unwrap() + 1,
                    ))
                    .take(self.typ.proof_len() * self.num_proofs())
            {
                *x -= y;
            }
        }

        // Prep the output messages: the leader's input share first, then one per helper.
        let mut out = Vec::with_capacity(num_aggregators as usize);
        out.push(Prio3InputShare {
            measurement_share: Share::Leader(leader_measurement_share),
            proofs_share: Share::Leader(leader_proofs_share),
            joint_rand_blind: leader_blind_opt,
        });

        for helper in helper_shares.into_iter() {
            out.push(Prio3InputShare {
                measurement_share: Share::Helper(helper.measurement_share),
                proofs_share: Share::Helper(helper.proofs_share),
                joint_rand_blind: helper.joint_rand_blind,
            });
        }

        Ok((public_share, out))
    }

    /// Convert an aggregator index into a `u8` aggregator ID, rejecting indices that are not
    /// below `num_aggregators`. The final `unwrap` cannot fail because `num_aggregators` is a
    /// `u8`, so any accepted index fits in a `u8`.
    fn role_try_from(&self, agg_id: usize) -> Result<u8, VdafError> {
        if agg_id >= self.num_aggregators as usize {
            return Err(VdafError::Uncategorized("unexpected aggregator id".into()));
        }
        Ok(u8::try_from(agg_id).unwrap())
    }
}
696
697impl<T, P, const SEED_SIZE: usize> Vdaf for Prio3<T, P, SEED_SIZE>
698where
699    T: Type,
700    P: Xof<SEED_SIZE>,
701{
702    type Measurement = T::Measurement;
703    type AggregateResult = T::AggregateResult;
704    type AggregationParam = ();
705    type PublicShare = Prio3PublicShare<SEED_SIZE>;
706    type InputShare = Prio3InputShare<T::Field, SEED_SIZE>;
707    type OutputShare = OutputShare<T::Field>;
708    type AggregateShare = AggregateShare<T::Field>;
709
710    fn algorithm_id(&self) -> u32 {
711        self.algorithm_id
712    }
713
714    fn num_aggregators(&self) -> usize {
715        self.num_aggregators as usize
716    }
717}
718
/// Message broadcast by the [`Client`] to every [`Aggregator`] during the Sharding phase.
///
/// For types that use joint randomness it carries one joint randomness part per aggregator (the
/// leader's first); otherwise it carries nothing and encodes to zero bytes.
#[derive(Clone, Debug)]
pub struct Prio3PublicShare<const SEED_SIZE: usize> {
    /// Contributions to the joint randomness from every aggregator's share, or `None` when the
    /// type does not use joint randomness.
    joint_rand_parts: Option<Vec<Seed<SEED_SIZE>>>,
}
725
726impl<const SEED_SIZE: usize> Encode for Prio3PublicShare<SEED_SIZE> {
727    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
728        if let Some(joint_rand_parts) = self.joint_rand_parts.as_ref() {
729            for part in joint_rand_parts.iter() {
730                part.encode(bytes)?;
731            }
732        }
733        Ok(())
734    }
735
736    fn encoded_len(&self) -> Option<usize> {
737        if let Some(joint_rand_parts) = self.joint_rand_parts.as_ref() {
738            // Each seed has the same size.
739            Some(SEED_SIZE * joint_rand_parts.len())
740        } else {
741            Some(0)
742        }
743    }
744}
745
746impl<const SEED_SIZE: usize> PartialEq for Prio3PublicShare<SEED_SIZE> {
747    fn eq(&self, other: &Self) -> bool {
748        self.ct_eq(other).into()
749    }
750}
751
752impl<const SEED_SIZE: usize> Eq for Prio3PublicShare<SEED_SIZE> {}
753
754impl<const SEED_SIZE: usize> ConstantTimeEq for Prio3PublicShare<SEED_SIZE> {
755    fn ct_eq(&self, other: &Self) -> Choice {
756        // We allow short-circuiting on the presence or absence of the joint_rand_parts.
757        option_ct_eq(
758            self.joint_rand_parts.as_deref(),
759            other.joint_rand_parts.as_deref(),
760        )
761    }
762}
763
764impl<T, P, const SEED_SIZE: usize> ParameterizedDecode<Prio3<T, P, SEED_SIZE>>
765    for Prio3PublicShare<SEED_SIZE>
766where
767    T: Type,
768    P: Xof<SEED_SIZE>,
769{
770    fn decode_with_param(
771        decoding_parameter: &Prio3<T, P, SEED_SIZE>,
772        bytes: &mut Cursor<&[u8]>,
773    ) -> Result<Self, CodecError> {
774        if decoding_parameter.typ.joint_rand_len() > 0 {
775            let joint_rand_parts = iter::repeat_with(|| Seed::<SEED_SIZE>::decode(bytes))
776                .take(decoding_parameter.num_aggregators.into())
777                .collect::<Result<Vec<_>, _>>()?;
778            Ok(Self {
779                joint_rand_parts: Some(joint_rand_parts),
780            })
781        } else {
782            Ok(Self {
783                joint_rand_parts: None,
784            })
785        }
786    }
787}
788
/// Message sent by the [`Client`] to each [`Aggregator`] during the Sharding phase.
///
/// As produced by sharding in this module, the leader's input share holds its shares in expanded
/// form (`Share::Leader`) while each helper's holds only seeds (`Share::Helper`).
#[derive(Clone, Debug)]
pub struct Prio3InputShare<F, const SEED_SIZE: usize> {
    /// The measurement share.
    measurement_share: Share<F, SEED_SIZE>,

    /// The proofs share: shares of all `num_proofs` proofs, concatenated.
    proofs_share: Share<F, SEED_SIZE>,

    /// Blinding seed used by the Aggregator to compute the joint randomness. This field is optional
    /// because not every [`Type`] requires joint randomness.
    joint_rand_blind: Option<Seed<SEED_SIZE>>,
}
802
803impl<F: ConstantTimeEq, const SEED_SIZE: usize> PartialEq for Prio3InputShare<F, SEED_SIZE> {
804    fn eq(&self, other: &Self) -> bool {
805        self.ct_eq(other).into()
806    }
807}
808
809impl<F: ConstantTimeEq, const SEED_SIZE: usize> Eq for Prio3InputShare<F, SEED_SIZE> {}
810
811impl<F: ConstantTimeEq, const SEED_SIZE: usize> ConstantTimeEq for Prio3InputShare<F, SEED_SIZE> {
812    fn ct_eq(&self, other: &Self) -> Choice {
813        // We allow short-circuiting on the presence or absence of the joint_rand_blind.
814        option_ct_eq(
815            self.joint_rand_blind.as_ref(),
816            other.joint_rand_blind.as_ref(),
817        ) & self.measurement_share.ct_eq(&other.measurement_share)
818            & self.proofs_share.ct_eq(&other.proofs_share)
819    }
820}
821
822impl<F: FftFriendlyFieldElement, const SEED_SIZE: usize> Encode for Prio3InputShare<F, SEED_SIZE> {
823    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
824        if matches!(
825            (&self.measurement_share, &self.proofs_share),
826            (Share::Leader(_), Share::Helper(_)) | (Share::Helper(_), Share::Leader(_))
827        ) {
828            panic!("tried to encode input share with ambiguous encoding")
829        }
830
831        self.measurement_share.encode(bytes)?;
832        self.proofs_share.encode(bytes)?;
833        if let Some(ref blind) = self.joint_rand_blind {
834            blind.encode(bytes)?;
835        }
836        Ok(())
837    }
838
839    fn encoded_len(&self) -> Option<usize> {
840        let mut len = self.measurement_share.encoded_len()? + self.proofs_share.encoded_len()?;
841        if let Some(ref blind) = self.joint_rand_blind {
842            len += blind.encoded_len()?;
843        }
844        Some(len)
845    }
846}
847
848impl<'a, T, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Prio3<T, P, SEED_SIZE>, usize)>
849    for Prio3InputShare<T::Field, SEED_SIZE>
850where
851    T: Type,
852    P: Xof<SEED_SIZE>,
853{
854    fn decode_with_param(
855        (prio3, agg_id): &(&'a Prio3<T, P, SEED_SIZE>, usize),
856        bytes: &mut Cursor<&[u8]>,
857    ) -> Result<Self, CodecError> {
858        let agg_id = prio3
859            .role_try_from(*agg_id)
860            .map_err(|e| CodecError::Other(Box::new(e)))?;
861        let (input_decoder, proof_decoder) = if agg_id == 0 {
862            (
863                ShareDecodingParameter::Leader(prio3.typ.input_len()),
864                ShareDecodingParameter::Leader(prio3.typ.proof_len() * prio3.num_proofs()),
865            )
866        } else {
867            (
868                ShareDecodingParameter::Helper,
869                ShareDecodingParameter::Helper,
870            )
871        };
872
873        let measurement_share = Share::decode_with_param(&input_decoder, bytes)?;
874        let proofs_share = Share::decode_with_param(&proof_decoder, bytes)?;
875        let joint_rand_blind = if prio3.typ.joint_rand_len() > 0 {
876            let blind = Seed::decode(bytes)?;
877            Some(blind)
878        } else {
879            None
880        };
881
882        Ok(Prio3InputShare {
883            measurement_share,
884            proofs_share,
885            joint_rand_blind,
886        })
887    }
888}
889
#[derive(Clone, Debug)]
/// Message broadcast by each [`Aggregator`] in each round of the Preparation phase.
pub struct Prio3PrepareShare<F, const SEED_SIZE: usize> {
    /// A share of the FLP verifier message. (See [`Type`].)
    ///
    /// When produced by `prepare_init`, this holds one verifier share per proof, concatenated.
    verifiers: Vec<F>,

    /// A part of the joint randomness seed.
    ///
    /// `prepare_init` sets this only when the FLP type uses joint randomness.
    joint_rand_part: Option<Seed<SEED_SIZE>>,
}
899
900impl<F: ConstantTimeEq, const SEED_SIZE: usize> PartialEq for Prio3PrepareShare<F, SEED_SIZE> {
901    fn eq(&self, other: &Self) -> bool {
902        self.ct_eq(other).into()
903    }
904}
905
906impl<F: ConstantTimeEq, const SEED_SIZE: usize> Eq for Prio3PrepareShare<F, SEED_SIZE> {}
907
908impl<F: ConstantTimeEq, const SEED_SIZE: usize> ConstantTimeEq for Prio3PrepareShare<F, SEED_SIZE> {
909    fn ct_eq(&self, other: &Self) -> Choice {
910        // We allow short-circuiting on the presence or absence of the joint_rand_part.
911        option_ct_eq(
912            self.joint_rand_part.as_ref(),
913            other.joint_rand_part.as_ref(),
914        ) & self.verifiers.ct_eq(&other.verifiers)
915    }
916}
917
918impl<F: FftFriendlyFieldElement, const SEED_SIZE: usize> Encode
919    for Prio3PrepareShare<F, SEED_SIZE>
920{
921    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
922        for x in &self.verifiers {
923            x.encode(bytes)?;
924        }
925        if let Some(ref seed) = self.joint_rand_part {
926            seed.encode(bytes)?;
927        }
928        Ok(())
929    }
930
931    fn encoded_len(&self) -> Option<usize> {
932        // Each element of the verifier has the same size.
933        let mut len = F::ENCODED_SIZE * self.verifiers.len();
934        if let Some(ref seed) = self.joint_rand_part {
935            len += seed.encoded_len()?;
936        }
937        Some(len)
938    }
939}
940
941impl<F: FftFriendlyFieldElement, const SEED_SIZE: usize>
942    ParameterizedDecode<Prio3PrepareState<F, SEED_SIZE>> for Prio3PrepareShare<F, SEED_SIZE>
943{
944    fn decode_with_param(
945        decoding_parameter: &Prio3PrepareState<F, SEED_SIZE>,
946        bytes: &mut Cursor<&[u8]>,
947    ) -> Result<Self, CodecError> {
948        let mut verifiers = Vec::with_capacity(decoding_parameter.verifiers_len);
949        for _ in 0..decoding_parameter.verifiers_len {
950            verifiers.push(F::decode(bytes)?);
951        }
952
953        let joint_rand_part = if decoding_parameter.joint_rand_seed.is_some() {
954            Some(Seed::decode(bytes)?)
955        } else {
956            None
957        };
958
959        Ok(Prio3PrepareShare {
960            verifiers,
961            joint_rand_part,
962        })
963    }
964}
965
#[derive(Clone, Debug)]
/// Result of combining a round of [`Prio3PrepareShare`] messages.
pub struct Prio3PrepareMessage<const SEED_SIZE: usize> {
    /// The joint randomness seed computed by the Aggregators.
    ///
    /// `None` when the FLP type does not use joint randomness. In `prepare_next`, each aggregator
    /// compares this against the seed it derived locally in `prepare_init`.
    joint_rand_seed: Option<Seed<SEED_SIZE>>,
}
972
973impl<const SEED_SIZE: usize> PartialEq for Prio3PrepareMessage<SEED_SIZE> {
974    fn eq(&self, other: &Self) -> bool {
975        self.ct_eq(other).into()
976    }
977}
978
979impl<const SEED_SIZE: usize> Eq for Prio3PrepareMessage<SEED_SIZE> {}
980
981impl<const SEED_SIZE: usize> ConstantTimeEq for Prio3PrepareMessage<SEED_SIZE> {
982    fn ct_eq(&self, other: &Self) -> Choice {
983        // We allow short-circuiting on the presnce or absence of the joint_rand_seed.
984        option_ct_eq(
985            self.joint_rand_seed.as_ref(),
986            other.joint_rand_seed.as_ref(),
987        )
988    }
989}
990
991impl<const SEED_SIZE: usize> Encode for Prio3PrepareMessage<SEED_SIZE> {
992    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
993        if let Some(ref seed) = self.joint_rand_seed {
994            seed.encode(bytes)?;
995        }
996        Ok(())
997    }
998
999    fn encoded_len(&self) -> Option<usize> {
1000        if let Some(ref seed) = self.joint_rand_seed {
1001            seed.encoded_len()
1002        } else {
1003            Some(0)
1004        }
1005    }
1006}
1007
1008impl<F: FftFriendlyFieldElement, const SEED_SIZE: usize>
1009    ParameterizedDecode<Prio3PrepareState<F, SEED_SIZE>> for Prio3PrepareMessage<SEED_SIZE>
1010{
1011    fn decode_with_param(
1012        decoding_parameter: &Prio3PrepareState<F, SEED_SIZE>,
1013        bytes: &mut Cursor<&[u8]>,
1014    ) -> Result<Self, CodecError> {
1015        let joint_rand_seed = if decoding_parameter.joint_rand_seed.is_some() {
1016            Some(Seed::decode(bytes)?)
1017        } else {
1018            None
1019        };
1020
1021        Ok(Prio3PrepareMessage { joint_rand_seed })
1022    }
1023}
1024
1025impl<T, P, const SEED_SIZE: usize> Client<16> for Prio3<T, P, SEED_SIZE>
1026where
1027    T: Type,
1028    P: Xof<SEED_SIZE>,
1029{
1030    #[allow(clippy::type_complexity)]
1031    fn shard(
1032        &self,
1033        measurement: &T::Measurement,
1034        nonce: &[u8; 16],
1035    ) -> Result<(Self::PublicShare, Vec<Prio3InputShare<T::Field, SEED_SIZE>>), VdafError> {
1036        let mut random = vec![0u8; self.random_size()];
1037        getrandom::getrandom(&mut random)?;
1038        self.shard_with_random(measurement, nonce, &random)
1039    }
1040}
1041
/// State of each [`Aggregator`] during the Preparation phase.
#[derive(Clone)]
pub struct Prio3PrepareState<F, const SEED_SIZE: usize> {
    /// This aggregator's measurement share: explicit for the leader, a seed for a helper.
    measurement_share: Share<F, SEED_SIZE>,
    /// Joint randomness seed derived in `prepare_init`. `None` when the FLP type does not use
    /// joint randomness; checked against the prepare message in `prepare_next`.
    joint_rand_seed: Option<Seed<SEED_SIZE>>,
    /// This aggregator's ID (0 is the leader).
    agg_id: u8,
    /// Number of verifier field elements expected in each prepare share; used when decoding
    /// peers' prepare shares.
    verifiers_len: usize,
}
1050
1051impl<F: ConstantTimeEq, const SEED_SIZE: usize> PartialEq for Prio3PrepareState<F, SEED_SIZE> {
1052    fn eq(&self, other: &Self) -> bool {
1053        self.ct_eq(other).into()
1054    }
1055}
1056
1057impl<F: ConstantTimeEq, const SEED_SIZE: usize> Eq for Prio3PrepareState<F, SEED_SIZE> {}
1058
1059impl<F: ConstantTimeEq, const SEED_SIZE: usize> ConstantTimeEq for Prio3PrepareState<F, SEED_SIZE> {
1060    fn ct_eq(&self, other: &Self) -> Choice {
1061        // We allow short-circuiting on the presence or absence of the joint_rand_seed, as well as
1062        // the aggregator ID & verifier length parameters.
1063        if self.agg_id != other.agg_id || self.verifiers_len != other.verifiers_len {
1064            return Choice::from(0);
1065        }
1066
1067        option_ct_eq(
1068            self.joint_rand_seed.as_ref(),
1069            other.joint_rand_seed.as_ref(),
1070        ) & self.measurement_share.ct_eq(&other.measurement_share)
1071    }
1072}
1073
1074impl<F, const SEED_SIZE: usize> Debug for Prio3PrepareState<F, SEED_SIZE> {
1075    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1076        f.debug_struct("Prio3PrepareState")
1077            .field("measurement_share", &"[redacted]")
1078            .field(
1079                "joint_rand_seed",
1080                match self.joint_rand_seed {
1081                    Some(_) => &"Some([redacted])",
1082                    None => &"None",
1083                },
1084            )
1085            .field("agg_id", &self.agg_id)
1086            .field("verifiers_len", &self.verifiers_len)
1087            .finish()
1088    }
1089}
1090
1091impl<F: FftFriendlyFieldElement, const SEED_SIZE: usize> Encode
1092    for Prio3PrepareState<F, SEED_SIZE>
1093{
1094    /// Append the encoded form of this object to the end of `bytes`, growing the vector as needed.
1095    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
1096        self.measurement_share.encode(bytes)?;
1097        if let Some(ref seed) = self.joint_rand_seed {
1098            seed.encode(bytes)?;
1099        }
1100        Ok(())
1101    }
1102
1103    fn encoded_len(&self) -> Option<usize> {
1104        let mut len = self.measurement_share.encoded_len()?;
1105        if let Some(ref seed) = self.joint_rand_seed {
1106            len += seed.encoded_len()?;
1107        }
1108        Some(len)
1109    }
1110}
1111
1112impl<'a, T, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Prio3<T, P, SEED_SIZE>, usize)>
1113    for Prio3PrepareState<T::Field, SEED_SIZE>
1114where
1115    T: Type,
1116    P: Xof<SEED_SIZE>,
1117{
1118    fn decode_with_param(
1119        (prio3, agg_id): &(&'a Prio3<T, P, SEED_SIZE>, usize),
1120        bytes: &mut Cursor<&[u8]>,
1121    ) -> Result<Self, CodecError> {
1122        let agg_id = prio3
1123            .role_try_from(*agg_id)
1124            .map_err(|e| CodecError::Other(Box::new(e)))?;
1125
1126        let share_decoder = if agg_id == 0 {
1127            ShareDecodingParameter::Leader(prio3.typ.input_len())
1128        } else {
1129            ShareDecodingParameter::Helper
1130        };
1131        let measurement_share = Share::decode_with_param(&share_decoder, bytes)?;
1132
1133        let joint_rand_seed = if prio3.typ.joint_rand_len() > 0 {
1134            Some(Seed::decode(bytes)?)
1135        } else {
1136            None
1137        };
1138
1139        Ok(Self {
1140            measurement_share,
1141            joint_rand_seed,
1142            agg_id,
1143            verifiers_len: prio3.typ.verifier_len() * prio3.num_proofs(),
1144        })
1145    }
1146}
1147
1148impl<T, P, const SEED_SIZE: usize> Aggregator<SEED_SIZE, 16> for Prio3<T, P, SEED_SIZE>
1149where
1150    T: Type,
1151    P: Xof<SEED_SIZE>,
1152{
1153    type PrepareState = Prio3PrepareState<T::Field, SEED_SIZE>;
1154    type PrepareShare = Prio3PrepareShare<T::Field, SEED_SIZE>;
1155    type PrepareMessage = Prio3PrepareMessage<SEED_SIZE>;
1156
1157    /// Begins the Prep process with the other aggregators. The result of this process is
1158    /// the aggregator's output share.
1159    #[allow(clippy::type_complexity)]
1160    fn prepare_init(
1161        &self,
1162        verify_key: &[u8; SEED_SIZE],
1163        agg_id: usize,
1164        _agg_param: &Self::AggregationParam,
1165        nonce: &[u8; 16],
1166        public_share: &Self::PublicShare,
1167        msg: &Prio3InputShare<T::Field, SEED_SIZE>,
1168    ) -> Result<
1169        (
1170            Prio3PrepareState<T::Field, SEED_SIZE>,
1171            Prio3PrepareShare<T::Field, SEED_SIZE>,
1172        ),
1173        VdafError,
1174    > {
1175        let agg_id = self.role_try_from(agg_id)?;
1176
1177        let measurement_share = match msg.measurement_share {
1178            Share::Leader(ref data) => Cow::Borrowed(data),
1179            Share::Helper(ref seed) => Cow::Owned(
1180                P::seed_stream(
1181                    seed,
1182                    &self.domain_separation_tag(DST_MEASUREMENT_SHARE),
1183                    &[agg_id],
1184                )
1185                .into_field_vec(self.typ.input_len()),
1186            ),
1187        };
1188
1189        let proofs_share = match msg.proofs_share {
1190            Share::Leader(ref data) => Cow::Borrowed(data),
1191            Share::Helper(ref seed) => Cow::Owned(
1192                self.derive_helper_proofs_share(seed, agg_id)
1193                    .take(self.typ.proof_len() * self.num_proofs())
1194                    .collect::<Vec<_>>(),
1195            ),
1196        };
1197
1198        // Compute the joint randomness.
1199        let (joint_rand_seed, joint_rand_part, joint_rands) = if self.typ.joint_rand_len() > 0 {
1200            let mut joint_rand_part_xof = P::init(
1201                msg.joint_rand_blind.as_ref().unwrap().as_ref(),
1202                &self.domain_separation_tag(DST_JOINT_RAND_PART),
1203            );
1204            joint_rand_part_xof.update(&[agg_id]);
1205            joint_rand_part_xof.update(nonce);
1206            let mut encoding_buffer = Vec::with_capacity(T::Field::ENCODED_SIZE);
1207            for x in measurement_share.iter() {
1208                x.encode(&mut encoding_buffer).map_err(|_| {
1209                    VdafError::Uncategorized("failed to encode measurement share".to_string())
1210                })?;
1211                joint_rand_part_xof.update(&encoding_buffer);
1212                encoding_buffer.clear();
1213            }
1214            let own_joint_rand_part = joint_rand_part_xof.into_seed();
1215
1216            // Make an iterator over the joint randomness parts, but use this aggregator's
1217            // contribution, computed from the input share, in lieu of the the corresponding part
1218            // from the public share.
1219            //
1220            // The locally computed part should match the part from the public share for honestly
1221            // generated reports. If they do not match, the joint randomness seed check during the
1222            // next round of preparation should fail.
1223            let corrected_joint_rand_parts = public_share
1224                .joint_rand_parts
1225                .iter()
1226                .flatten()
1227                .take(agg_id as usize)
1228                .chain(iter::once(&own_joint_rand_part))
1229                .chain(
1230                    public_share
1231                        .joint_rand_parts
1232                        .iter()
1233                        .flatten()
1234                        .skip(agg_id as usize + 1),
1235                );
1236
1237            let (joint_rand_seed, joint_rands) =
1238                self.derive_joint_rands(corrected_joint_rand_parts);
1239
1240            (
1241                Some(joint_rand_seed),
1242                Some(own_joint_rand_part),
1243                joint_rands,
1244            )
1245        } else {
1246            (None, None, Vec::new())
1247        };
1248
1249        // Run the query-generation algorithm.
1250        let query_rands = self.derive_query_rands(verify_key, nonce);
1251        let mut verifiers_share = Vec::with_capacity(self.typ.verifier_len() * self.num_proofs());
1252        for p in 0..self.num_proofs() {
1253            let query_rand =
1254                &query_rands[p * self.typ.query_rand_len()..(p + 1) * self.typ.query_rand_len()];
1255            let joint_rand =
1256                &joint_rands[p * self.typ.joint_rand_len()..(p + 1) * self.typ.joint_rand_len()];
1257            let proof_share =
1258                &proofs_share[p * self.typ.proof_len()..(p + 1) * self.typ.proof_len()];
1259
1260            verifiers_share.append(&mut self.typ.query(
1261                measurement_share.as_ref(),
1262                proof_share,
1263                query_rand,
1264                joint_rand,
1265                self.num_aggregators as usize,
1266            )?);
1267        }
1268
1269        Ok((
1270            Prio3PrepareState {
1271                measurement_share: msg.measurement_share.clone(),
1272                joint_rand_seed,
1273                agg_id,
1274                verifiers_len: verifiers_share.len(),
1275            },
1276            Prio3PrepareShare {
1277                verifiers: verifiers_share,
1278                joint_rand_part,
1279            },
1280        ))
1281    }
1282
1283    fn prepare_shares_to_prepare_message<
1284        M: IntoIterator<Item = Prio3PrepareShare<T::Field, SEED_SIZE>>,
1285    >(
1286        &self,
1287        _: &Self::AggregationParam,
1288        inputs: M,
1289    ) -> Result<Prio3PrepareMessage<SEED_SIZE>, VdafError> {
1290        let mut verifiers = vec![T::Field::zero(); self.typ.verifier_len() * self.num_proofs()];
1291        let mut joint_rand_parts = Vec::with_capacity(self.num_aggregators());
1292        let mut count = 0;
1293        for share in inputs.into_iter() {
1294            count += 1;
1295
1296            if share.verifiers.len() != verifiers.len() {
1297                return Err(VdafError::Uncategorized(format!(
1298                    "unexpected verifier share length: got {}; want {}",
1299                    share.verifiers.len(),
1300                    verifiers.len(),
1301                )));
1302            }
1303
1304            if self.typ.joint_rand_len() > 0 {
1305                let joint_rand_seed_part = share.joint_rand_part.unwrap();
1306                joint_rand_parts.push(joint_rand_seed_part);
1307            }
1308
1309            for (x, y) in verifiers.iter_mut().zip(share.verifiers) {
1310                *x += y;
1311            }
1312        }
1313
1314        if count != self.num_aggregators {
1315            return Err(VdafError::Uncategorized(format!(
1316                "unexpected message count: got {}; want {}",
1317                count, self.num_aggregators,
1318            )));
1319        }
1320
1321        // Check the proof verifiers.
1322        for verifier in verifiers.chunks(self.typ.verifier_len()) {
1323            if !self.typ.decide(verifier)? {
1324                return Err(VdafError::Uncategorized(
1325                    "proof verifier check failed".into(),
1326                ));
1327            }
1328        }
1329
1330        let joint_rand_seed = if self.typ.joint_rand_len() > 0 {
1331            Some(self.derive_joint_rand_seed(joint_rand_parts.iter()))
1332        } else {
1333            None
1334        };
1335
1336        Ok(Prio3PrepareMessage { joint_rand_seed })
1337    }
1338
1339    fn prepare_next(
1340        &self,
1341        step: Prio3PrepareState<T::Field, SEED_SIZE>,
1342        msg: Prio3PrepareMessage<SEED_SIZE>,
1343    ) -> Result<PrepareTransition<Self, SEED_SIZE, 16>, VdafError> {
1344        if self.typ.joint_rand_len() > 0 {
1345            // Check that the joint randomness was correct.
1346            if step
1347                .joint_rand_seed
1348                .as_ref()
1349                .unwrap()
1350                .ct_ne(msg.joint_rand_seed.as_ref().unwrap())
1351                .into()
1352            {
1353                return Err(VdafError::Uncategorized(
1354                    "joint randomness mismatch".to_string(),
1355                ));
1356            }
1357        }
1358
1359        // Compute the output share.
1360        let measurement_share = match step.measurement_share {
1361            Share::Leader(data) => data,
1362            Share::Helper(seed) => {
1363                let dst = self.domain_separation_tag(DST_MEASUREMENT_SHARE);
1364                P::seed_stream(&seed, &dst, &[step.agg_id]).into_field_vec(self.typ.input_len())
1365            }
1366        };
1367
1368        let output_share = match self.typ.truncate(measurement_share) {
1369            Ok(data) => OutputShare(data),
1370            Err(err) => {
1371                return Err(VdafError::from(err));
1372            }
1373        };
1374
1375        Ok(PrepareTransition::Finish(output_share))
1376    }
1377
1378    /// Aggregates a sequence of output shares into an aggregate share.
1379    fn aggregate<It: IntoIterator<Item = OutputShare<T::Field>>>(
1380        &self,
1381        _agg_param: &(),
1382        output_shares: It,
1383    ) -> Result<AggregateShare<T::Field>, VdafError> {
1384        let mut agg_share = AggregateShare(vec![T::Field::zero(); self.typ.output_len()]);
1385        for output_share in output_shares.into_iter() {
1386            agg_share.accumulate(&output_share)?;
1387        }
1388
1389        Ok(agg_share)
1390    }
1391}
1392
#[cfg(feature = "experimental")]
impl<T, P, S, const SEED_SIZE: usize> AggregatorWithNoise<SEED_SIZE, 16, S>
    for Prio3<T, P, SEED_SIZE>
where
    T: TypeWithNoise<S>,
    P: Xof<SEED_SIZE>,
    S: DifferentialPrivacyStrategy,
{
    /// Adds differential-privacy noise to the aggregate share in place. The noise is produced by
    /// the FLP type for the given strategy and measurement count.
    fn add_noise_to_agg_share(
        &self,
        dp_strategy: &S,
        _agg_param: &Self::AggregationParam,
        agg_share: &mut Self::AggregateShare,
        num_measurements: usize,
    ) -> Result<(), VdafError> {
        let share_elements = &mut agg_share.0;
        self.typ
            .add_noise_to_result(dp_strategy, share_elements, num_measurements)?;
        Ok(())
    }
}
1413
1414impl<T, P, const SEED_SIZE: usize> Collector for Prio3<T, P, SEED_SIZE>
1415where
1416    T: Type,
1417    P: Xof<SEED_SIZE>,
1418{
1419    /// Combines aggregate shares into the aggregate result.
1420    fn unshard<It: IntoIterator<Item = AggregateShare<T::Field>>>(
1421        &self,
1422        _agg_param: &Self::AggregationParam,
1423        agg_shares: It,
1424        num_measurements: usize,
1425    ) -> Result<T::AggregateResult, VdafError> {
1426        let mut agg = AggregateShare(vec![T::Field::zero(); self.typ.output_len()]);
1427        for agg_share in agg_shares.into_iter() {
1428            agg.merge(&agg_share)?;
1429        }
1430
1431        Ok(self.typ.decode_result(&agg.0, num_measurements)?)
1432    }
1433}
1434
/// A helper's input share in compressed form: the seeds from which its measurement and proof
/// shares are expanded, plus the joint randomness blind when the FLP type uses one.
#[derive(Clone)]
struct HelperShare<const SEED_SIZE: usize> {
    /// Seed that expands to the measurement share.
    measurement_share: Seed<SEED_SIZE>,
    /// Seed that expands to the proof share(s).
    proofs_share: Seed<SEED_SIZE>,
    /// Joint randomness blind; `None` when joint randomness is not used.
    joint_rand_blind: Option<Seed<SEED_SIZE>>,
}
1441
1442impl<const SEED_SIZE: usize> HelperShare<SEED_SIZE> {
1443    fn from_seeds(
1444        measurement_share: [u8; SEED_SIZE],
1445        proof_share: [u8; SEED_SIZE],
1446        joint_rand_blind: Option<[u8; SEED_SIZE]>,
1447    ) -> Self {
1448        HelperShare {
1449            measurement_share: Seed::from_bytes(measurement_share),
1450            proofs_share: Seed::from_bytes(proof_share),
1451            joint_rand_blind: joint_rand_blind.map(Seed::from_bytes),
1452        }
1453    }
1454}
1455
1456fn check_num_aggregators(num_aggregators: u8) -> Result<(), VdafError> {
1457    if num_aggregators == 0 {
1458        return Err(VdafError::Uncategorized(format!(
1459            "at least one aggregator is required; got {num_aggregators}"
1460        )));
1461    } else if num_aggregators > 254 {
1462        return Err(VdafError::Uncategorized(format!(
1463            "number of aggregators must not exceed 254; got {num_aggregators}"
1464        )));
1465    }
1466
1467    Ok(())
1468}
1469
1470impl<'a, F, T, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Prio3<T, P, SEED_SIZE>, &'a ())>
1471    for OutputShare<F>
1472where
1473    F: FieldElement,
1474    T: Type,
1475    P: Xof<SEED_SIZE>,
1476{
1477    fn decode_with_param(
1478        (vdaf, _): &(&'a Prio3<T, P, SEED_SIZE>, &'a ()),
1479        bytes: &mut Cursor<&[u8]>,
1480    ) -> Result<Self, CodecError> {
1481        decode_fieldvec(vdaf.output_len(), bytes).map(Self)
1482    }
1483}
1484
1485impl<'a, F, T, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Prio3<T, P, SEED_SIZE>, &'a ())>
1486    for AggregateShare<F>
1487where
1488    F: FieldElement,
1489    T: Type,
1490    P: Xof<SEED_SIZE>,
1491{
1492    fn decode_with_param(
1493        (vdaf, _): &(&'a Prio3<T, P, SEED_SIZE>, &'a ()),
1494        bytes: &mut Cursor<&[u8]>,
1495    ) -> Result<Self, CodecError> {
1496        decode_fieldvec(vdaf.output_len(), bytes).map(Self)
1497    }
1498}
1499
1500/// This function determines equality between two optional, constant-time comparable values. It
1501/// short-circuits on the existence (but not contents) of the values -- a timing side-channel may
1502/// reveal whether the values match on Some or None.
1503#[inline]
1504fn option_ct_eq<T>(left: Option<&T>, right: Option<&T>) -> Choice
1505where
1506    T: ConstantTimeEq + ?Sized,
1507{
1508    match (left, right) {
1509        (Some(left), Some(right)) => left.ct_eq(right),
1510        (None, None) => Choice::from(1),
1511        _ => Choice::from(0),
1512    }
1513}
1514
/// Returns the floor of the base-2 logarithm of `input`.
///
/// This is a polyfill for `usize::ilog2()`, which only exists in Rust 1.67 and later, and it
/// mirrors the standard library's implementation. Remove it once the MSRV reaches 1.67.
///
/// # Panics
///
/// Panics if `input` is zero.
fn ilog2(input: usize) -> u32 {
    assert!(input != 0, "Tried to take the logarithm of zero");
    // The index of the highest set bit is the width minus one minus the leading zero count.
    usize::BITS - 1 - input.leading_zeros()
}
1528
1529/// Finds the optimal choice of chunk length for [`Prio3Histogram`] or [`Prio3SumVec`], given its
1530/// encoded measurement length. For [`Prio3Histogram`], the measurement length is equal to the
1531/// length parameter. For [`Prio3SumVec`], the measurement length is equal to the product of the
1532/// length and bits parameters.
1533pub fn optimal_chunk_length(measurement_length: usize) -> usize {
1534    if measurement_length <= 1 {
1535        return 1;
1536    }
1537
1538    /// Candidate set of parameter choices for the parallel sum optimization.
1539    struct Candidate {
1540        gadget_calls: usize,
1541        chunk_length: usize,
1542    }
1543
1544    let max_log2 = ilog2(measurement_length + 1);
1545    let best_opt = (1..=max_log2)
1546        .rev()
1547        .map(|log2| {
1548            let gadget_calls = (1 << log2) - 1;
1549            let chunk_length = (measurement_length + gadget_calls - 1) / gadget_calls;
1550            Candidate {
1551                gadget_calls,
1552                chunk_length,
1553            }
1554        })
1555        .min_by_key(|candidate| {
1556            // Compute the proof length, in field elements, for either Prio3Histogram or Prio3SumVec
1557            (candidate.chunk_length * 2)
1558                + 2 * ((1 + candidate.gadget_calls).next_power_of_two() - 1)
1559        });
1560    // Unwrap safety: max_log2 must be at least 1, because smaller measurement_length inputs are
1561    // dealt with separately. Thus, the range iterator that the search is over will be nonempty,
1562    // and min_by_key() will always return Some.
1563    best_opt.unwrap().chunk_length
1564}
1565
1566#[cfg(test)]
1567mod tests {
1568    use super::*;
1569    #[cfg(feature = "experimental")]
1570    use crate::flp::gadgets::ParallelSumGadget;
1571    use crate::vdaf::{
1572        equality_comparison_test, fieldvec_roundtrip_test,
1573        test_utils::{run_vdaf, run_vdaf_prepare},
1574    };
1575    use assert_matches::assert_matches;
1576    #[cfg(feature = "experimental")]
1577    use fixed::{
1578        types::extra::{U15, U31, U63},
1579        FixedI16, FixedI32, FixedI64,
1580    };
1581    #[cfg(feature = "experimental")]
1582    use fixed_macro::fixed;
1583    use rand::prelude::*;
1584
1585    #[test]
1586    fn test_prio3_count() {
1587        let prio3 = Prio3::new_count(2).unwrap();
1588
1589        assert_eq!(
1590            run_vdaf(&prio3, &(), [true, false, false, true, true]).unwrap(),
1591            3
1592        );
1593
1594        let mut nonce = [0; 16];
1595        let mut verify_key = [0; 16];
1596        thread_rng().fill(&mut verify_key[..]);
1597        thread_rng().fill(&mut nonce[..]);
1598
1599        let (public_share, input_shares) = prio3.shard(&false, &nonce).unwrap();
1600        run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares).unwrap();
1601
1602        let (public_share, input_shares) = prio3.shard(&true, &nonce).unwrap();
1603        run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares).unwrap();
1604
1605        test_serialization(&prio3, &true, &nonce).unwrap();
1606
1607        let prio3_extra_helper = Prio3::new_count(3).unwrap();
1608        assert_eq!(
1609            run_vdaf(&prio3_extra_helper, &(), [true, false, false, true, true]).unwrap(),
1610            3,
1611        );
1612    }
1613
1614    #[test]
1615    fn test_prio3_sum() {
1616        let prio3 = Prio3::new_sum(3, 16).unwrap();
1617
1618        assert_eq!(
1619            run_vdaf(&prio3, &(), [0, (1 << 16) - 1, 0, 1, 1]).unwrap(),
1620            (1 << 16) + 1
1621        );
1622
1623        let mut verify_key = [0; 16];
1624        thread_rng().fill(&mut verify_key[..]);
1625        let nonce = [0; 16];
1626
1627        let (public_share, mut input_shares) = prio3.shard(&1, &nonce).unwrap();
1628        input_shares[0].joint_rand_blind.as_mut().unwrap().0[0] ^= 255;
1629        let result = run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares);
1630        assert_matches!(result, Err(VdafError::Uncategorized(_)));
1631
1632        let (public_share, mut input_shares) = prio3.shard(&1, &nonce).unwrap();
1633        assert_matches!(input_shares[0].measurement_share, Share::Leader(ref mut data) => {
1634            data[0] += Field128::one();
1635        });
1636        let result = run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares);
1637        assert_matches!(result, Err(VdafError::Uncategorized(_)));
1638
1639        let (public_share, mut input_shares) = prio3.shard(&1, &nonce).unwrap();
1640        assert_matches!(input_shares[0].proofs_share, Share::Leader(ref mut data) => {
1641                data[0] += Field128::one();
1642        });
1643        let result = run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares);
1644        assert_matches!(result, Err(VdafError::Uncategorized(_)));
1645
1646        test_serialization(&prio3, &1, &nonce).unwrap();
1647    }
1648
1649    #[test]
1650    fn test_prio3_sum_vec() {
1651        let prio3 = Prio3::new_sum_vec(2, 2, 20, 4).unwrap();
1652        assert_eq!(
1653            run_vdaf(
1654                &prio3,
1655                &(),
1656                [
1657                    vec![0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1],
1658                    vec![0, 2, 0, 0, 1, 0, 0, 0, 1, 1, 1, 3, 0, 3, 0, 0, 0, 1, 0, 0],
1659                    vec![1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1],
1660                ]
1661            )
1662            .unwrap(),
1663            vec![1, 3, 1, 0, 3, 1, 0, 1, 2, 2, 3, 3, 1, 5, 1, 2, 1, 3, 0, 2],
1664        );
1665    }
1666
1667    #[test]
1668    fn test_prio3_sum_vec_multiproof() {
1669        let prio3 = Prio3::<
1670            SumVec<Field128, ParallelSum<Field128, Mul<Field128>>>,
1671            XofTurboShake128,
1672            16,
1673        >::new(2, 2, 0xFFFF0000, SumVec::new(2, 20, 4).unwrap())
1674        .unwrap();
1675
1676        assert_eq!(
1677            run_vdaf(
1678                &prio3,
1679                &(),
1680                [
1681                    vec![0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1],
1682                    vec![0, 2, 0, 0, 1, 0, 0, 0, 1, 1, 1, 3, 0, 3, 0, 0, 0, 1, 0, 0],
1683                    vec![1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1],
1684                ]
1685            )
1686            .unwrap(),
1687            vec![1, 3, 1, 0, 3, 1, 0, 1, 2, 2, 3, 3, 1, 5, 1, 2, 1, 3, 0, 2],
1688        );
1689    }
1690
1691    #[test]
1692    #[cfg(feature = "multithreaded")]
1693    fn test_prio3_sum_vec_multithreaded() {
1694        let prio3 = Prio3::new_sum_vec_multithreaded(2, 2, 20, 4).unwrap();
1695        assert_eq!(
1696            run_vdaf(
1697                &prio3,
1698                &(),
1699                [
1700                    vec![0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1],
1701                    vec![0, 2, 0, 0, 1, 0, 0, 0, 1, 1, 1, 3, 0, 3, 0, 0, 0, 1, 0, 0],
1702                    vec![1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1],
1703                ]
1704            )
1705            .unwrap(),
1706            vec![1, 3, 1, 0, 3, 1, 0, 1, 2, 2, 3, 3, 1, 5, 1, 2, 1, 3, 0, 2],
1707        );
1708    }
1709
1710    #[test]
1711    #[cfg(feature = "experimental")]
1712    fn test_prio3_bounded_fpvec_sum_unaligned() {
1713        type P<Fx> = Prio3FixedPointBoundedL2VecSum<Fx>;
1714        #[cfg(feature = "multithreaded")]
1715        type PM<Fx> = Prio3FixedPointBoundedL2VecSumMultithreaded<Fx>;
1716        let ctor_32 = P::<FixedI32<U31>>::new_fixedpoint_boundedl2_vec_sum;
1717        #[cfg(feature = "multithreaded")]
1718        let ctor_mt_32 = PM::<FixedI32<U31>>::new_fixedpoint_boundedl2_vec_sum_multithreaded;
1719
1720        {
1721            const SIZE: usize = 5;
1722            let fp32_0 = fixed!(0: I1F31);
1723
1724            // 32 bit fixedpoint, non-power-of-2 vector, single-threaded
1725            {
1726                let prio3_32 = ctor_32(2, SIZE).unwrap();
1727                test_fixed_vec::<_, _, _, SIZE>(fp32_0, prio3_32);
1728            }
1729
1730            // 32 bit fixedpoint, non-power-of-2 vector, multi-threaded
1731            #[cfg(feature = "multithreaded")]
1732            {
1733                let prio3_mt_32 = ctor_mt_32(2, SIZE).unwrap();
1734                test_fixed_vec::<_, _, _, SIZE>(fp32_0, prio3_mt_32);
1735            }
1736        }
1737
1738        fn test_fixed_vec<Fx, PE, M, const SIZE: usize>(
1739            fp_0: Fx,
1740            prio3: Prio3<FixedPointBoundedL2VecSum<Fx, PE, M>, XofTurboShake128, 16>,
1741        ) where
1742            Fx: Fixed + CompatibleFloat + std::ops::Neg<Output = Fx>,
1743            PE: Eq + ParallelSumGadget<Field128, PolyEval<Field128>> + Clone + 'static,
1744            M: Eq + ParallelSumGadget<Field128, Mul<Field128>> + Clone + 'static,
1745        {
1746            let fp_vec = vec![fp_0; SIZE];
1747
1748            let measurements = [fp_vec.clone(), fp_vec];
1749            assert_eq!(
1750                run_vdaf(&prio3, &(), measurements).unwrap(),
1751                vec![0.0; SIZE]
1752            );
1753        }
1754    }
1755
1756    #[test]
1757    #[cfg(feature = "experimental")]
1758    fn test_prio3_bounded_fpvec_sum() {
1759        type P<Fx> = Prio3FixedPointBoundedL2VecSum<Fx>;
1760        let ctor_16 = P::<FixedI16<U15>>::new_fixedpoint_boundedl2_vec_sum;
1761        let ctor_32 = P::<FixedI32<U31>>::new_fixedpoint_boundedl2_vec_sum;
1762        let ctor_64 = P::<FixedI64<U63>>::new_fixedpoint_boundedl2_vec_sum;
1763
1764        #[cfg(feature = "multithreaded")]
1765        type PM<Fx> = Prio3FixedPointBoundedL2VecSumMultithreaded<Fx>;
1766        #[cfg(feature = "multithreaded")]
1767        let ctor_mt_16 = PM::<FixedI16<U15>>::new_fixedpoint_boundedl2_vec_sum_multithreaded;
1768        #[cfg(feature = "multithreaded")]
1769        let ctor_mt_32 = PM::<FixedI32<U31>>::new_fixedpoint_boundedl2_vec_sum_multithreaded;
1770        #[cfg(feature = "multithreaded")]
1771        let ctor_mt_64 = PM::<FixedI64<U63>>::new_fixedpoint_boundedl2_vec_sum_multithreaded;
1772
1773        {
1774            // 16 bit fixedpoint
1775            let fp16_4_inv = fixed!(0.25: I1F15);
1776            let fp16_8_inv = fixed!(0.125: I1F15);
1777            let fp16_16_inv = fixed!(0.0625: I1F15);
1778
1779            // two aggregators, three entries per vector.
1780            {
1781                let prio3_16 = ctor_16(2, 3).unwrap();
1782                test_fixed(fp16_4_inv, fp16_8_inv, fp16_16_inv, prio3_16);
1783            }
1784
1785            #[cfg(feature = "multithreaded")]
1786            {
1787                let prio3_16_mt = ctor_mt_16(2, 3).unwrap();
1788                test_fixed(fp16_4_inv, fp16_8_inv, fp16_16_inv, prio3_16_mt);
1789            }
1790        }
1791
1792        {
1793            // 32 bit fixedpoint
1794            let fp32_4_inv = fixed!(0.25: I1F31);
1795            let fp32_8_inv = fixed!(0.125: I1F31);
1796            let fp32_16_inv = fixed!(0.0625: I1F31);
1797
1798            {
1799                let prio3_32 = ctor_32(2, 3).unwrap();
1800                test_fixed(fp32_4_inv, fp32_8_inv, fp32_16_inv, prio3_32);
1801            }
1802
1803            #[cfg(feature = "multithreaded")]
1804            {
1805                let prio3_32_mt = ctor_mt_32(2, 3).unwrap();
1806                test_fixed(fp32_4_inv, fp32_8_inv, fp32_16_inv, prio3_32_mt);
1807            }
1808        }
1809
1810        {
1811            // 64 bit fixedpoint
1812            let fp64_4_inv = fixed!(0.25: I1F63);
1813            let fp64_8_inv = fixed!(0.125: I1F63);
1814            let fp64_16_inv = fixed!(0.0625: I1F63);
1815
1816            {
1817                let prio3_64 = ctor_64(2, 3).unwrap();
1818                test_fixed(fp64_4_inv, fp64_8_inv, fp64_16_inv, prio3_64);
1819            }
1820
1821            #[cfg(feature = "multithreaded")]
1822            {
1823                let prio3_64_mt = ctor_mt_64(2, 3).unwrap();
1824                test_fixed(fp64_4_inv, fp64_8_inv, fp64_16_inv, prio3_64_mt);
1825            }
1826        }
1827
1828        fn test_fixed<Fx, PE, M>(
1829            fp_4_inv: Fx,
1830            fp_8_inv: Fx,
1831            fp_16_inv: Fx,
1832            prio3: Prio3<FixedPointBoundedL2VecSum<Fx, PE, M>, XofTurboShake128, 16>,
1833        ) where
1834            Fx: Fixed + CompatibleFloat + std::ops::Neg<Output = Fx>,
1835            PE: Eq + ParallelSumGadget<Field128, PolyEval<Field128>> + Clone + 'static,
1836            M: Eq + ParallelSumGadget<Field128, Mul<Field128>> + Clone + 'static,
1837        {
1838            let fp_vec1 = vec![fp_4_inv, fp_8_inv, fp_16_inv];
1839            let fp_vec2 = vec![fp_4_inv, fp_8_inv, fp_16_inv];
1840
1841            let fp_vec3 = vec![-fp_4_inv, -fp_8_inv, -fp_16_inv];
1842            let fp_vec4 = vec![-fp_4_inv, -fp_8_inv, -fp_16_inv];
1843
1844            let fp_vec5 = vec![fp_4_inv, -fp_8_inv, -fp_16_inv];
1845            let fp_vec6 = vec![fp_4_inv, fp_8_inv, fp_16_inv];
1846
1847            // positive entries
1848            let fp_list = [fp_vec1, fp_vec2];
1849            assert_eq!(
1850                run_vdaf(&prio3, &(), fp_list).unwrap(),
1851                vec!(0.5, 0.25, 0.125),
1852            );
1853
1854            // negative entries
1855            let fp_list2 = [fp_vec3, fp_vec4];
1856            assert_eq!(
1857                run_vdaf(&prio3, &(), fp_list2).unwrap(),
1858                vec!(-0.5, -0.25, -0.125),
1859            );
1860
1861            // both
1862            let fp_list3 = [fp_vec5, fp_vec6];
1863            assert_eq!(
1864                run_vdaf(&prio3, &(), fp_list3).unwrap(),
1865                vec!(0.5, 0.0, 0.0),
1866            );
1867
1868            let mut verify_key = [0; 16];
1869            let mut nonce = [0; 16];
1870            thread_rng().fill(&mut verify_key);
1871            thread_rng().fill(&mut nonce);
1872
1873            let (public_share, mut input_shares) = prio3
1874                .shard(&vec![fp_4_inv, fp_8_inv, fp_16_inv], &nonce)
1875                .unwrap();
1876            input_shares[0].joint_rand_blind.as_mut().unwrap().0[0] ^= 255;
1877            let result =
1878                run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares);
1879            assert_matches!(result, Err(VdafError::Uncategorized(_)));
1880
1881            let (public_share, mut input_shares) = prio3
1882                .shard(&vec![fp_4_inv, fp_8_inv, fp_16_inv], &nonce)
1883                .unwrap();
1884            assert_matches!(input_shares[0].measurement_share, Share::Leader(ref mut data) => {
1885                data[0] += Field128::one();
1886            });
1887            let result =
1888                run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares);
1889            assert_matches!(result, Err(VdafError::Uncategorized(_)));
1890
1891            let (public_share, mut input_shares) = prio3
1892                .shard(&vec![fp_4_inv, fp_8_inv, fp_16_inv], &nonce)
1893                .unwrap();
1894            assert_matches!(input_shares[0].proofs_share, Share::Leader(ref mut data) => {
1895                    data[0] += Field128::one();
1896            });
1897            let result =
1898                run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares);
1899            assert_matches!(result, Err(VdafError::Uncategorized(_)));
1900
1901            test_serialization(&prio3, &vec![fp_4_inv, fp_8_inv, fp_16_inv], &nonce).unwrap();
1902        }
1903    }
1904
1905    #[test]
1906    fn test_prio3_histogram() {
1907        let prio3 = Prio3::new_histogram(2, 4, 2).unwrap();
1908
1909        assert_eq!(
1910            run_vdaf(&prio3, &(), [0, 1, 2, 3]).unwrap(),
1911            vec![1, 1, 1, 1]
1912        );
1913        assert_eq!(run_vdaf(&prio3, &(), [0]).unwrap(), vec![1, 0, 0, 0]);
1914        assert_eq!(run_vdaf(&prio3, &(), [1]).unwrap(), vec![0, 1, 0, 0]);
1915        assert_eq!(run_vdaf(&prio3, &(), [2]).unwrap(), vec![0, 0, 1, 0]);
1916        assert_eq!(run_vdaf(&prio3, &(), [3]).unwrap(), vec![0, 0, 0, 1]);
1917        test_serialization(&prio3, &3, &[0; 16]).unwrap();
1918    }
1919
1920    #[test]
1921    #[cfg(feature = "multithreaded")]
1922    fn test_prio3_histogram_multithreaded() {
1923        let prio3 = Prio3::new_histogram_multithreaded(2, 4, 2).unwrap();
1924
1925        assert_eq!(
1926            run_vdaf(&prio3, &(), [0, 1, 2, 3]).unwrap(),
1927            vec![1, 1, 1, 1]
1928        );
1929        assert_eq!(run_vdaf(&prio3, &(), [0]).unwrap(), vec![1, 0, 0, 0]);
1930        assert_eq!(run_vdaf(&prio3, &(), [1]).unwrap(), vec![0, 1, 0, 0]);
1931        assert_eq!(run_vdaf(&prio3, &(), [2]).unwrap(), vec![0, 0, 1, 0]);
1932        assert_eq!(run_vdaf(&prio3, &(), [3]).unwrap(), vec![0, 0, 0, 1]);
1933        test_serialization(&prio3, &3, &[0; 16]).unwrap();
1934    }
1935
1936    #[test]
1937    fn test_prio3_average() {
1938        let prio3 = Prio3::new_average(2, 64).unwrap();
1939
1940        assert_eq!(run_vdaf(&prio3, &(), [17, 8]).unwrap(), 12.5f64);
1941        assert_eq!(run_vdaf(&prio3, &(), [1, 1, 1, 1]).unwrap(), 1f64);
1942        assert_eq!(run_vdaf(&prio3, &(), [0, 0, 0, 1]).unwrap(), 0.25f64);
1943        assert_eq!(
1944            run_vdaf(&prio3, &(), [1, 11, 111, 1111, 3, 8]).unwrap(),
1945            207.5f64
1946        );
1947    }
1948
1949    #[test]
1950    fn test_prio3_input_share() {
1951        let prio3 = Prio3::new_sum(5, 16).unwrap();
1952        let (_public_share, input_shares) = prio3.shard(&1, &[0; 16]).unwrap();
1953
1954        // Check that seed shares are distinct.
1955        for (i, x) in input_shares.iter().enumerate() {
1956            for (j, y) in input_shares.iter().enumerate() {
1957                if i != j {
1958                    if let (Share::Helper(left), Share::Helper(right)) =
1959                        (&x.measurement_share, &y.measurement_share)
1960                    {
1961                        assert_ne!(left, right);
1962                    }
1963
1964                    if let (Share::Helper(left), Share::Helper(right)) =
1965                        (&x.proofs_share, &y.proofs_share)
1966                    {
1967                        assert_ne!(left, right);
1968                    }
1969
1970                    assert_ne!(x.joint_rand_blind, y.joint_rand_blind);
1971                }
1972            }
1973        }
1974    }
1975
1976    fn test_serialization<T, P, const SEED_SIZE: usize>(
1977        prio3: &Prio3<T, P, SEED_SIZE>,
1978        measurement: &T::Measurement,
1979        nonce: &[u8; 16],
1980    ) -> Result<(), VdafError>
1981    where
1982        T: Type,
1983        P: Xof<SEED_SIZE>,
1984    {
1985        let mut verify_key = [0; SEED_SIZE];
1986        thread_rng().fill(&mut verify_key[..]);
1987        let (public_share, input_shares) = prio3.shard(measurement, nonce)?;
1988
1989        let encoded_public_share = public_share.get_encoded().unwrap();
1990        let decoded_public_share =
1991            Prio3PublicShare::get_decoded_with_param(prio3, &encoded_public_share)
1992                .expect("failed to decode public share");
1993        assert_eq!(decoded_public_share, public_share);
1994        assert_eq!(
1995            public_share.encoded_len().unwrap(),
1996            encoded_public_share.len()
1997        );
1998
1999        for (agg_id, input_share) in input_shares.iter().enumerate() {
2000            let encoded_input_share = input_share.get_encoded().unwrap();
2001            let decoded_input_share =
2002                Prio3InputShare::get_decoded_with_param(&(prio3, agg_id), &encoded_input_share)
2003                    .expect("failed to decode input share");
2004            assert_eq!(&decoded_input_share, input_share);
2005            assert_eq!(
2006                input_share.encoded_len().unwrap(),
2007                encoded_input_share.len()
2008            );
2009        }
2010
2011        let mut prepare_shares = Vec::new();
2012        let mut last_prepare_state = None;
2013        for (agg_id, input_share) in input_shares.iter().enumerate() {
2014            let (prepare_state, prepare_share) =
2015                prio3.prepare_init(&verify_key, agg_id, &(), nonce, &public_share, input_share)?;
2016
2017            let encoded_prepare_state = prepare_state.get_encoded().unwrap();
2018            let decoded_prepare_state =
2019                Prio3PrepareState::get_decoded_with_param(&(prio3, agg_id), &encoded_prepare_state)
2020                    .expect("failed to decode prepare state");
2021            assert_eq!(decoded_prepare_state, prepare_state);
2022            assert_eq!(
2023                prepare_state.encoded_len().unwrap(),
2024                encoded_prepare_state.len()
2025            );
2026
2027            let encoded_prepare_share = prepare_share.get_encoded().unwrap();
2028            let decoded_prepare_share =
2029                Prio3PrepareShare::get_decoded_with_param(&prepare_state, &encoded_prepare_share)
2030                    .expect("failed to decode prepare share");
2031            assert_eq!(decoded_prepare_share, prepare_share);
2032            assert_eq!(
2033                prepare_share.encoded_len().unwrap(),
2034                encoded_prepare_share.len()
2035            );
2036
2037            prepare_shares.push(prepare_share);
2038            last_prepare_state = Some(prepare_state);
2039        }
2040
2041        let prepare_message = prio3
2042            .prepare_shares_to_prepare_message(&(), prepare_shares)
2043            .unwrap();
2044
2045        let encoded_prepare_message = prepare_message.get_encoded().unwrap();
2046        let decoded_prepare_message = Prio3PrepareMessage::get_decoded_with_param(
2047            &last_prepare_state.unwrap(),
2048            &encoded_prepare_message,
2049        )
2050        .expect("failed to decode prepare message");
2051        assert_eq!(decoded_prepare_message, prepare_message);
2052        assert_eq!(
2053            prepare_message.encoded_len().unwrap(),
2054            encoded_prepare_message.len()
2055        );
2056
2057        Ok(())
2058    }
2059
2060    #[test]
2061    fn roundtrip_output_share() {
2062        let vdaf = Prio3::new_count(2).unwrap();
2063        fieldvec_roundtrip_test::<Field64, Prio3Count, OutputShare<Field64>>(&vdaf, &(), 1);
2064
2065        let vdaf = Prio3::new_sum(2, 17).unwrap();
2066        fieldvec_roundtrip_test::<Field128, Prio3Sum, OutputShare<Field128>>(&vdaf, &(), 1);
2067
2068        let vdaf = Prio3::new_histogram(2, 12, 3).unwrap();
2069        fieldvec_roundtrip_test::<Field128, Prio3Histogram, OutputShare<Field128>>(&vdaf, &(), 12);
2070    }
2071
2072    #[test]
2073    fn roundtrip_aggregate_share() {
2074        let vdaf = Prio3::new_count(2).unwrap();
2075        fieldvec_roundtrip_test::<Field64, Prio3Count, AggregateShare<Field64>>(&vdaf, &(), 1);
2076
2077        let vdaf = Prio3::new_sum(2, 17).unwrap();
2078        fieldvec_roundtrip_test::<Field128, Prio3Sum, AggregateShare<Field128>>(&vdaf, &(), 1);
2079
2080        let vdaf = Prio3::new_histogram(2, 12, 3).unwrap();
2081        fieldvec_roundtrip_test::<Field128, Prio3Histogram, AggregateShare<Field128>>(
2082            &vdaf,
2083            &(),
2084            12,
2085        );
2086    }
2087
2088    #[test]
2089    fn public_share_equality_test() {
2090        equality_comparison_test(&[
2091            Prio3PublicShare {
2092                joint_rand_parts: Some(Vec::from([Seed([0])])),
2093            },
2094            Prio3PublicShare {
2095                joint_rand_parts: Some(Vec::from([Seed([1])])),
2096            },
2097            Prio3PublicShare {
2098                joint_rand_parts: None,
2099            },
2100        ])
2101    }
2102
2103    #[test]
2104    fn input_share_equality_test() {
2105        equality_comparison_test(&[
2106            // Default.
2107            Prio3InputShare {
2108                measurement_share: Share::Leader(Vec::from([0])),
2109                proofs_share: Share::Leader(Vec::from([1])),
2110                joint_rand_blind: Some(Seed([2])),
2111            },
2112            // Modified measurement share.
2113            Prio3InputShare {
2114                measurement_share: Share::Leader(Vec::from([100])),
2115                proofs_share: Share::Leader(Vec::from([1])),
2116                joint_rand_blind: Some(Seed([2])),
2117            },
2118            // Modified proof share.
2119            Prio3InputShare {
2120                measurement_share: Share::Leader(Vec::from([0])),
2121                proofs_share: Share::Leader(Vec::from([101])),
2122                joint_rand_blind: Some(Seed([2])),
2123            },
2124            // Modified joint_rand_blind.
2125            Prio3InputShare {
2126                measurement_share: Share::Leader(Vec::from([0])),
2127                proofs_share: Share::Leader(Vec::from([1])),
2128                joint_rand_blind: Some(Seed([102])),
2129            },
2130            // Missing joint_rand_blind.
2131            Prio3InputShare {
2132                measurement_share: Share::Leader(Vec::from([0])),
2133                proofs_share: Share::Leader(Vec::from([1])),
2134                joint_rand_blind: None,
2135            },
2136        ])
2137    }
2138
2139    #[test]
2140    fn prepare_share_equality_test() {
2141        equality_comparison_test(&[
2142            // Default.
2143            Prio3PrepareShare {
2144                verifiers: Vec::from([0]),
2145                joint_rand_part: Some(Seed([1])),
2146            },
2147            // Modified verifier.
2148            Prio3PrepareShare {
2149                verifiers: Vec::from([100]),
2150                joint_rand_part: Some(Seed([1])),
2151            },
2152            // Modified joint_rand_part.
2153            Prio3PrepareShare {
2154                verifiers: Vec::from([0]),
2155                joint_rand_part: Some(Seed([101])),
2156            },
2157            // Missing joint_rand_part.
2158            Prio3PrepareShare {
2159                verifiers: Vec::from([0]),
2160                joint_rand_part: None,
2161            },
2162        ])
2163    }
2164
2165    #[test]
2166    fn prepare_message_equality_test() {
2167        equality_comparison_test(&[
2168            // Default.
2169            Prio3PrepareMessage {
2170                joint_rand_seed: Some(Seed([0])),
2171            },
2172            // Modified joint_rand_seed.
2173            Prio3PrepareMessage {
2174                joint_rand_seed: Some(Seed([100])),
2175            },
2176            // Missing joint_rand_seed.
2177            Prio3PrepareMessage {
2178                joint_rand_seed: None,
2179            },
2180        ])
2181    }
2182
2183    #[test]
2184    fn prepare_state_equality_test() {
2185        equality_comparison_test(&[
2186            // Default.
2187            Prio3PrepareState {
2188                measurement_share: Share::Leader(Vec::from([0])),
2189                joint_rand_seed: Some(Seed([1])),
2190                agg_id: 2,
2191                verifiers_len: 3,
2192            },
2193            // Modified measurement share.
2194            Prio3PrepareState {
2195                measurement_share: Share::Leader(Vec::from([100])),
2196                joint_rand_seed: Some(Seed([1])),
2197                agg_id: 2,
2198                verifiers_len: 3,
2199            },
2200            // Modified joint_rand_seed.
2201            Prio3PrepareState {
2202                measurement_share: Share::Leader(Vec::from([0])),
2203                joint_rand_seed: Some(Seed([101])),
2204                agg_id: 2,
2205                verifiers_len: 3,
2206            },
2207            // Missing joint_rand_seed.
2208            Prio3PrepareState {
2209                measurement_share: Share::Leader(Vec::from([0])),
2210                joint_rand_seed: None,
2211                agg_id: 2,
2212                verifiers_len: 3,
2213            },
2214            // Modified agg_id.
2215            Prio3PrepareState {
2216                measurement_share: Share::Leader(Vec::from([0])),
2217                joint_rand_seed: Some(Seed([1])),
2218                agg_id: 102,
2219                verifiers_len: 3,
2220            },
2221            // Modified verifier_len.
2222            Prio3PrepareState {
2223                measurement_share: Share::Leader(Vec::from([0])),
2224                joint_rand_seed: Some(Seed([1])),
2225                agg_id: 2,
2226                verifiers_len: 103,
2227            },
2228        ])
2229    }
2230
2231    #[test]
2232    fn test_optimal_chunk_length() {
2233        // nonsense argument, but make sure it doesn't panic.
2234        optimal_chunk_length(0);
2235
2236        // edge cases on either side of power-of-two jumps
2237        assert_eq!(optimal_chunk_length(1), 1);
2238        assert_eq!(optimal_chunk_length(2), 2);
2239        assert_eq!(optimal_chunk_length(3), 1);
2240        assert_eq!(optimal_chunk_length(18), 6);
2241        assert_eq!(optimal_chunk_length(19), 3);
2242
2243        // additional arbitrary test cases
2244        assert_eq!(optimal_chunk_length(40), 6);
2245        assert_eq!(optimal_chunk_length(10_000), 79);
2246        assert_eq!(optimal_chunk_length(100_000), 393);
2247
2248        // confirm that the chunk lengths are truly optimal
2249        for measurement_length in [2, 3, 4, 5, 18, 19, 40] {
2250            let optimal_chunk_length = optimal_chunk_length(measurement_length);
2251            let optimal_proof_length = Histogram::<Field128, ParallelSum<_, _>>::new(
2252                measurement_length,
2253                optimal_chunk_length,
2254            )
2255            .unwrap()
2256            .proof_len();
2257            for chunk_length in 1..=measurement_length {
2258                let proof_length =
2259                    Histogram::<Field128, ParallelSum<_, _>>::new(measurement_length, chunk_length)
2260                        .unwrap()
2261                        .proof_len();
2262                assert!(proof_length >= optimal_proof_length);
2263            }
2264        }
2265    }
2266}