prio 0.10.0

Implementation of the Prio aggregation system core: https://crypto.stanford.edu/prio/
Documentation
// SPDX-License-Identifier: MPL-2.0

//! Implementation of the Prio3 VDAF [[draft-irtf-cfrg-vdaf-03]].
//!
//! **WARNING:** Neither this code nor the cryptographic construction it implements has undergone
//! significant security analysis. Use at your own risk.
//!
//! Prio3 is based on the Prio system designed by Dan Boneh and Henry Corrigan-Gibbs and presented
//! at NSDI 2017 [[CGB17]]. However, it incorporates a few techniques from Boneh et al., CRYPTO
//! 2019 [[BBCG+19]], that lead to substantial improvements in terms of run time and communication
//! cost.
//!
//! Prio3 is a transformation of a Fully Linear Proof (FLP) system [[draft-irtf-cfrg-vdaf-03]] into
//! a VDAF. The base type, [`Prio3`], supports a wide variety of aggregation functions, some of
//! which are instantiated here:
//!
//! - [`Prio3Aes128Count`] for aggregating a counter (*)
//! - [`Prio3Aes128CountVec`] for aggregating a vector of counters
//! - [`Prio3Aes128Sum`] for computing the sum of integers (*)
//! - [`Prio3Aes128Histogram`] for estimating a distribution via a histogram (*)
//!
//! Additional types can be constructed from [`Prio3`] as needed.
//!
//! (*) denotes that the type is specified in [[draft-irtf-cfrg-vdaf-03]].
//!
//! [BBCG+19]: https://ia.cr/2019/188
//! [CGB17]: https://crypto.stanford.edu/prio/
//! [draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/

#[cfg(feature = "crypto-dependencies")]
use super::prg::PrgAes128;
use super::{DST_LEN, VERSION};
use crate::codec::{CodecError, Decode, Encode, ParameterizedDecode};
use crate::field::FieldElement;
#[cfg(feature = "crypto-dependencies")]
use crate::field::{Field128, Field64};
#[cfg(feature = "multithreaded")]
use crate::flp::gadgets::ParallelSumMultithreaded;
#[cfg(feature = "crypto-dependencies")]
use crate::flp::gadgets::{BlindPolyEval, ParallelSum, PolyEval};
#[cfg(feature = "crypto-dependencies")]
use crate::flp::types::fixedpoint_l2::compatible_float::CompatibleFloat;
#[cfg(feature = "crypto-dependencies")]
use crate::flp::types::fixedpoint_l2::FixedPointBoundedL2VecSum;
#[cfg(feature = "crypto-dependencies")]
use crate::flp::types::{Average, Count, CountVec, Histogram, Sum};
use crate::flp::Type;
use crate::prng::Prng;
use crate::vdaf::prg::{Prg, RandSource, Seed};
use crate::vdaf::{
    Aggregatable, AggregateShare, Aggregator, Client, Collector, OutputShare, PrepareTransition,
    Share, ShareDecodingParameter, Vdaf, VdafError,
};
#[cfg(feature = "crypto-dependencies")]
use fixed::traits::Fixed;
use std::convert::TryFrom;
use std::fmt::Debug;
use std::io::Cursor;
use std::iter::IntoIterator;
use std::marker::PhantomData;

/// The count type. Each measurement is an integer in `[0,2)` and the aggregate result is the sum.
#[cfg(feature = "crypto-dependencies")]
pub type Prio3Aes128Count = Prio3<Count<Field64>, PrgAes128, 16>;

#[cfg(feature = "crypto-dependencies")]
impl Prio3Aes128Count {
    /// Build a [`Prio3Aes128Count`] instance shared among `num_aggregators` aggregators.
    ///
    /// # Errors
    ///
    /// Returns whatever error [`Prio3::new`] reports for `num_aggregators`.
    pub fn new_aes128_count(num_aggregators: u8) -> Result<Self, VdafError> {
        let count = Count::new();
        Self::new(num_aggregators, count)
    }
}

/// The count-vector type. Each measurement is a vector of integers in `[0,2)` and the aggregate is
/// the element-wise sum.
#[cfg(feature = "crypto-dependencies")]
pub type Prio3Aes128CountVec =
    Prio3<CountVec<Field128, ParallelSum<Field128, BlindPolyEval<Field128>>>, PrgAes128, 16>;

#[cfg(feature = "crypto-dependencies")]
impl Prio3Aes128CountVec {
    /// Build a [`Prio3Aes128CountVec`] instance shared among `num_aggregators` aggregators.
    /// `len` is the length of each measurement vector.
    ///
    /// # Errors
    ///
    /// Returns whatever error [`Prio3::new`] reports for `num_aggregators`.
    pub fn new_aes128_count_vec(num_aggregators: u8, len: usize) -> Result<Self, VdafError> {
        let count_vec = CountVec::new(len);
        Self::new(num_aggregators, count_vec)
    }
}

/// Like [`Prio3Aes128CountVec`] except this type uses multithreading to improve sharding and
/// preparation time. Note that the improvement is only noticeable for very large input lengths,
/// e.g., 201 and up. (Your system's mileage may vary.)
#[cfg(feature = "multithreaded")]
#[cfg(feature = "crypto-dependencies")]
#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))]
pub type Prio3Aes128CountVecMultithreaded = Prio3<
    CountVec<Field128, ParallelSumMultithreaded<Field128, BlindPolyEval<Field128>>>,
    PrgAes128,
    16,
>;

#[cfg(feature = "multithreaded")]
#[cfg(feature = "crypto-dependencies")]
#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))]
impl Prio3Aes128CountVecMultithreaded {
    /// Build a [`Prio3Aes128CountVecMultithreaded`] instance shared among `num_aggregators`
    /// aggregators. `len` is the length of each measurement vector.
    ///
    /// # Errors
    ///
    /// Returns whatever error [`Prio3::new`] reports for `num_aggregators`.
    pub fn new_aes128_count_vec_multithreaded(
        num_aggregators: u8,
        len: usize,
    ) -> Result<Self, VdafError> {
        let count_vec = CountVec::new(len);
        Self::new(num_aggregators, count_vec)
    }
}

/// The sum type. Each measurement is an integer in `[0,2^bits)` for a bit length `bits` chosen at
/// construction time, and the aggregate is the sum.
#[cfg(feature = "crypto-dependencies")]
pub type Prio3Aes128Sum = Prio3<Sum<Field128>, PrgAes128, 16>;

#[cfg(feature = "crypto-dependencies")]
impl Prio3Aes128Sum {
    /// Build a [`Prio3Aes128Sum`] instance shared among `num_aggregators` aggregators, for
    /// measurements of at most `bits` bits.
    ///
    /// # Errors
    ///
    /// Fails if `bits` is greater than 64, if `Sum::new` rejects the bit length, or if
    /// [`Prio3::new`] rejects `num_aggregators`.
    pub fn new_aes128_sum(num_aggregators: u8, bits: u32) -> Result<Self, VdafError> {
        if bits <= 64 {
            // The FLP type is built first, so its error takes precedence over an invalid
            // aggregator count.
            Self::new(num_aggregators, Sum::new(bits as usize)?)
        } else {
            Err(VdafError::Uncategorized(format!(
                "bit length ({}) exceeds limit for aggregate type (64)",
                bits
            )))
        }
    }
}

/// The fixed point vector sum type. Each measurement is a vector of fixed point numbers
/// and the aggregate is the sum represented as 64-bit floats. The preparation phase
/// ensures the L2 norm of the input vector is < 1.
///
/// This is useful for aggregating gradients in a federated version of
/// [gradient descent](https://en.wikipedia.org/wiki/Gradient_descent) with
/// [differential privacy](https://en.wikipedia.org/wiki/Differential_privacy),
/// useful, e.g., for [differentially private deep learning](https://arxiv.org/pdf/1607.00133.pdf).
/// The bound on input norms is required for differential privacy. The fixed point representation
/// allows an easy conversion to the integer type used in internal computation, while leaving
/// conversion to the client. The model itself will have floating point parameters, so the output
/// sum has that type as well.
#[cfg(feature = "crypto-dependencies")]
pub type Prio3Aes128FixedPointBoundedL2VecSum<Fx> = Prio3<
    FixedPointBoundedL2VecSum<
        Fx,
        Field128,
        ParallelSum<Field128, PolyEval<Field128>>,
        ParallelSum<Field128, BlindPolyEval<Field128>>,
    >,
    PrgAes128,
    16,
>;

#[cfg(feature = "crypto-dependencies")]
impl<Fx: Fixed + CompatibleFloat<Field128>> Prio3Aes128FixedPointBoundedL2VecSum<Fx> {
    /// Build an instance of this VDAF shared among `num_aggregators` aggregators, where each
    /// measurement vector has `entries` entries.
    ///
    /// # Errors
    ///
    /// Fails if `num_aggregators` does not pass `check_num_aggregators` or if
    /// `FixedPointBoundedL2VecSum::new` rejects `entries`.
    pub fn new_aes128_fixedpoint_boundedl2_vec_sum(
        num_aggregators: u8,
        entries: usize,
    ) -> Result<Self, VdafError> {
        // Validate the aggregator count up front so that its error is reported before any error
        // from constructing the FLP type.
        check_num_aggregators(num_aggregators)?;
        let vec_sum = FixedPointBoundedL2VecSum::new(entries)?;
        Self::new(num_aggregators, vec_sum)
    }
}

/// The fixed point vector sum type. Each measurement is a vector of fixed point numbers
/// and the aggregate is the sum represented as 64-bit floats. The verification function
/// ensures the L2 norm of the input vector is < 1.
///
/// This alias is built from the `ParallelSumMultithreaded` gadget rather than `ParallelSum`.
#[cfg(feature = "multithreaded")]
#[cfg(feature = "crypto-dependencies")]
#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))]
pub type Prio3Aes128FixedPointBoundedL2VecSumMultithreaded<Fx> = Prio3<
    FixedPointBoundedL2VecSum<
        Fx,
        Field128,
        ParallelSumMultithreaded<Field128, PolyEval<Field128>>,
        ParallelSumMultithreaded<Field128, BlindPolyEval<Field128>>,
    >,
    PrgAes128,
    16,
>;

#[cfg(feature = "multithreaded")]
#[cfg(feature = "crypto-dependencies")]
#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))]
impl<Fx: Fixed + CompatibleFloat<Field128>> Prio3Aes128FixedPointBoundedL2VecSumMultithreaded<Fx> {
    /// Build an instance of this VDAF shared among `num_aggregators` aggregators, where each
    /// measurement vector has `entries` entries.
    ///
    /// # Errors
    ///
    /// Fails if `num_aggregators` does not pass `check_num_aggregators` or if
    /// `FixedPointBoundedL2VecSum::new` rejects `entries`.
    pub fn new_aes128_fixedpoint_boundedl2_vec_sum_multithreaded(
        num_aggregators: u8,
        entries: usize,
    ) -> Result<Self, VdafError> {
        // Validate the aggregator count up front so that its error is reported before any error
        // from constructing the FLP type.
        check_num_aggregators(num_aggregators)?;
        let vec_sum = FixedPointBoundedL2VecSum::new(entries)?;
        Self::new(num_aggregators, vec_sum)
    }
}

/// The histogram type. Each measurement is an unsigned integer and the result is a histogram
/// representation of the distribution. The bucket boundaries are fixed in advance.
#[cfg(feature = "crypto-dependencies")]
pub type Prio3Aes128Histogram = Prio3<Histogram<Field128>, PrgAes128, 16>;

#[cfg(feature = "crypto-dependencies")]
impl Prio3Aes128Histogram {
    /// Build a [`Prio3Aes128Histogram`] instance shared among `num_aggregators` aggregators, using
    /// `buckets` as the histogram bucket boundaries.
    ///
    /// # Errors
    ///
    /// Fails if `Histogram::new` rejects the boundaries or if [`Prio3::new`] rejects
    /// `num_aggregators`.
    pub fn new_aes128_histogram(num_aggregators: u8, buckets: &[u64]) -> Result<Self, VdafError> {
        // Widen each boundary losslessly from `u64` to `u128`.
        let buckets = buckets.iter().copied().map(u128::from).collect();

        Self::new(num_aggregators, Histogram::new(buckets)?)
    }
}

/// The average type. Each measurement is an integer in `[0,2^bits)` for a bit length `bits`
/// chosen at construction time, and the aggregate is the arithmetic average.
#[cfg(feature = "crypto-dependencies")]
pub type Prio3Aes128Average = Prio3<Average<Field128>, PrgAes128, 16>;

#[cfg(feature = "crypto-dependencies")]
impl Prio3Aes128Average {
    /// Build a [`Prio3Aes128Average`] instance shared among `num_aggregators` aggregators, for
    /// measurements of at most `bits` bits.
    ///
    /// # Errors
    ///
    /// Fails, in this order, if `num_aggregators` does not pass `check_num_aggregators`, if
    /// `bits` is greater than 64, or if `Average::new` rejects the bit length.
    pub fn new_aes128_average(num_aggregators: u8, bits: u32) -> Result<Self, VdafError> {
        check_num_aggregators(num_aggregators)?;

        if bits <= 64 {
            let typ = Average::new(bits as usize)?;
            Ok(Prio3 {
                num_aggregators,
                typ,
                phantom: PhantomData,
            })
        } else {
            Err(VdafError::Uncategorized(format!(
                "bit length ({}) exceeds limit for aggregate type (64)",
                bits
            )))
        }
    }
}

/// The base type for Prio3.
///
/// An instance of Prio3 is determined by:
///
/// - a [`Type`](crate::flp::Type) that defines the set of valid input measurements; and
/// - a [`Prg`](crate::vdaf::prg::Prg) for deriving vectors of field elements from seeds.
///
/// New instances can be defined by aliasing the base type. For example, [`Prio3Aes128Count`] is an
/// alias for `Prio3<Count<Field64>, PrgAes128, 16>`.
///
/// ```
/// use prio::vdaf::{
///     Aggregator, Client, Collector, PrepareTransition,
///     prio3::Prio3,
/// };
/// use rand::prelude::*;
///
/// let num_shares = 2;
/// let vdaf = Prio3::new_aes128_count(num_shares).unwrap();
///
/// let mut out_shares = vec![vec![]; num_shares.into()];
/// let mut rng = thread_rng();
/// let verify_key = rng.gen();
/// let measurements = [0, 1, 1, 1, 0];
/// for measurement in measurements {
///     // Shard
///     let (public_share, input_shares) = vdaf.shard(&measurement).unwrap();
///     let mut nonce = [0; 16];
///     rng.fill(&mut nonce);
///
///     // Prepare
///     let mut prep_states = vec![];
///     let mut prep_shares = vec![];
///     for (agg_id, input_share) in input_shares.iter().enumerate() {
///         let (state, share) = vdaf.prepare_init(
///             &verify_key,
///             agg_id,
///             &(),
///             &nonce,
///             &public_share,
///             input_share
///         ).unwrap();
///         prep_states.push(state);
///         prep_shares.push(share);
///     }
///     let prep_msg = vdaf.prepare_preprocess(prep_shares).unwrap();
///
///     for (agg_id, state) in prep_states.into_iter().enumerate() {
///         let out_share = match vdaf.prepare_step(state, prep_msg.clone()).unwrap() {
///             PrepareTransition::Finish(out_share) => out_share,
///             _ => panic!("unexpected transition"),
///         };
///         out_shares[agg_id].push(out_share);
///     }
/// }
///
/// // Aggregate
/// let agg_shares = out_shares.into_iter()
///     .map(|o| vdaf.aggregate(&(), o).unwrap());
///
/// // Unshard
/// let agg_res = vdaf.unshard(&(), agg_shares, measurements.len()).unwrap();
/// assert_eq!(agg_res, 3);
/// ```
///
/// [draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/
#[derive(Clone, Debug)]
pub struct Prio3<T, P, const L: usize>
where
    T: Type,
    P: Prg<L>,
{
    /// Number of aggregators the measurement is split among. The constructors in this file run
    /// it through `check_num_aggregators` before building the struct.
    num_aggregators: u8,
    /// The FLP type that encodes measurements and generates/verifies proofs.
    typ: T,
    /// Zero-sized marker for the PRG type `P`, which no field stores directly.
    phantom: PhantomData<P>,
}

impl<T, P, const L: usize> Prio3<T, P, L>
where
    T: Type,
    P: Prg<L>,
{
    /// Build a Prio3 VDAF over the FLP type `typ`, shared among `num_aggregators` aggregators.
    ///
    /// # Errors
    ///
    /// Fails if `num_aggregators` does not pass `check_num_aggregators`.
    pub fn new(num_aggregators: u8, typ: T) -> Result<Self, VdafError> {
        check_num_aggregators(num_aggregators)?;

        let phantom = PhantomData;
        Ok(Self {
            num_aggregators,
            typ,
            phantom,
        })
    }

    /// The output length of the underlying FLP.
    pub fn output_len(&self) -> usize {
        self.typ.output_len()
    }

    /// The verifier length of the underlying FLP.
    pub fn verifier_len(&self) -> usize {
        self.typ.verifier_len()
    }

    /// Derive the joint randomness seed from the per-aggregator seed parts.
    ///
    /// The PRG is keyed with an all-zero seed, fed the string
    /// `VERSION || ID (4 bytes, big endian) || 255`, and then fed each part in iteration order.
    /// The order of `parts` therefore affects the result.
    fn derive_joint_randomness<'a>(parts: impl Iterator<Item = &'a Seed<L>>) -> Seed<L> {
        let version_end = VERSION.len();
        let id_end = version_end + 4;

        let mut info = [0; VERSION.len() + 5];
        info[..version_end].copy_from_slice(VERSION);
        info[version_end..id_end].copy_from_slice(&Self::ID.to_be_bytes());
        info[id_end] = 255;

        let mut deriver = P::init(&[0; L]);
        deriver.update(&info);
        parts.for_each(|part| {
            deriver.update(part.as_ref());
        });
        deriver.into_seed()
    }

    /// Split `measurement` into one input share per aggregator, drawing every random value from
    /// `rand_source`.
    ///
    /// The returned vector has the leader's share at index 0, followed by one share per helper
    /// in aggregator-ID order. The leader's input and proof shares are explicit vectors
    /// (`Share::Leader`); each helper's shares are seeds (`Share::Helper`) that are expanded with
    /// the PRG. The joint randomness parameter of each share is `Some` only when the FLP type
    /// has a non-zero joint randomness length.
    ///
    /// `rand_source` is consulted in a fixed order: by `HelperShare::from_rand_source` for each
    /// helper, then for the leader's blind, then for the prove-randomness seed. Callers that
    /// supply a deterministic source (see `test_vec_shard`) depend on this order.
    ///
    /// # Errors
    ///
    /// Propagates errors from encoding the measurement, from `rand_source`, and from proof
    /// generation.
    fn shard_with_rand_source(
        &self,
        measurement: &T::Measurement,
        rand_source: RandSource,
    ) -> Result<Vec<Prio3InputShare<T::Field, L>>, VdafError> {
        // `info` is VERSION || ID (big endian) || one trailing byte. The trailing byte is
        // overwritten below with the ID of the aggregator currently being processed; the first
        // `DST_LEN` bytes alone are used as the domain separation tag.
        let mut info = [0; DST_LEN + 1];
        info[..VERSION.len()].copy_from_slice(VERSION);
        info[VERSION.len()..DST_LEN].copy_from_slice(&Self::ID.to_be_bytes());

        let num_aggregators = self.num_aggregators;
        let input = self.typ.encode_measurement(measurement)?;

        // Generate the input shares and compute the joint randomness.
        let mut helper_shares = Vec::with_capacity(num_aggregators as usize - 1);
        // `None` when the FLP type uses no joint randomness; the seed parts are then discarded.
        let mut helper_joint_rand_parts = if self.typ.joint_rand_len() > 0 {
            Some(Vec::with_capacity(num_aggregators as usize - 1))
        } else {
            None
        };
        // The leader's share starts as the full encoded input; each helper's expanded share is
        // subtracted from it so that all shares sum to the input.
        let mut leader_input_share = input.clone();
        for agg_id in 1..num_aggregators {
            let helper = HelperShare::from_rand_source(rand_source)?;

            // The helper's joint randomness seed part is derived from its blind, `info`
            // (including its aggregator ID), and every element of its expanded input share.
            let mut deriver = P::init(helper.joint_rand_param.blind.as_ref());
            info[DST_LEN] = agg_id;
            deriver.update(&info);
            let prng: Prng<T::Field, _> =
                Prng::from_seed_stream(P::seed_stream(&helper.input_share, &info));
            for (x, y) in leader_input_share
                .iter_mut()
                .zip(prng)
                .take(self.typ.input_len())
            {
                *x -= y;
                deriver.update(&y.into());
            }

            if let Some(helper_joint_rand_parts) = helper_joint_rand_parts.as_mut() {
                helper_joint_rand_parts.push(deriver.into_seed());
            }
            helper_shares.push(helper);
        }

        let leader_blind = Seed::from_rand_source(rand_source)?;

        // The leader's seed part is derived the same way, from its blind, `info`, and its
        // (now final) input share.
        info[DST_LEN] = 0; // ID of the leader
        let mut deriver = P::init(leader_blind.as_ref());
        deriver.update(&info);
        for x in leader_input_share.iter() {
            deriver.update(&(*x).into());
        }

        let leader_joint_rand_seed_part = deriver.into_seed();

        // Compute the joint randomness seed. The parts are combined in aggregator-ID order:
        // leader first, then the helpers.
        let joint_rand_seed = helper_joint_rand_parts.as_ref().map(|parts| {
            Self::derive_joint_randomness(
                std::iter::once(&leader_joint_rand_seed_part).chain(parts.iter()),
            )
        });

        // Run the proof-generation algorithm.
        let domain_separation_tag = &info[..DST_LEN];
        // Expand the joint randomness seed; this is an empty vector when no joint randomness is
        // used.
        let joint_rand: Vec<T::Field> = joint_rand_seed
            .map(|joint_rand_seed| {
                let prng: Prng<T::Field, _> =
                    Prng::from_seed_stream(P::seed_stream(&joint_rand_seed, domain_separation_tag));
                prng.take(self.typ.joint_rand_len()).collect()
            })
            .unwrap_or_default();
        // The prove randomness is expanded from a fresh seed drawn from `rand_source`.
        let prng: Prng<T::Field, _> = Prng::from_seed_stream(P::seed_stream(
            &Seed::from_rand_source(rand_source)?,
            domain_separation_tag,
        ));
        let prove_rand: Vec<T::Field> = prng.take(self.typ.prove_rand_len()).collect();
        // The proof is generated over the full input; the leader's proof share starts as the
        // whole proof and has each helper's expanded proof share subtracted below.
        let mut leader_proof_share = self.typ.prove(&input, &prove_rand, &joint_rand)?;

        // Generate the proof shares and distribute the joint randomness seed hints.
        for (j, helper) in helper_shares.iter_mut().enumerate() {
            // `helper_shares[j]` belongs to the aggregator with ID `j + 1`.
            info[DST_LEN] = j as u8 + 1;
            let prng: Prng<T::Field, _> =
                Prng::from_seed_stream(P::seed_stream(&helper.proof_share, &info));
            for (x, y) in leader_proof_share
                .iter_mut()
                .zip(prng)
                .take(self.typ.proof_len())
            {
                *x -= y;
            }

            if let Some(helper_joint_rand_parts) = helper_joint_rand_parts.as_ref() {
                // The hint holds every aggregator's seed part except this helper's own, in
                // aggregator-ID order.
                let mut hint = Vec::with_capacity(num_aggregators as usize - 1);
                hint.push(leader_joint_rand_seed_part.clone());
                hint.extend(helper_joint_rand_parts[..j].iter().cloned());
                hint.extend(helper_joint_rand_parts[j + 1..].iter().cloned());
                helper.joint_rand_param.seed_hint = hint;
            }
        }

        // The leader's hint is the list of all helper seed parts.
        let leader_joint_rand_param = if self.typ.joint_rand_len() > 0 {
            Some(JointRandParam {
                seed_hint: helper_joint_rand_parts.unwrap_or_default(),
                blind: leader_blind,
            })
        } else {
            None
        };

        // Prep the output messages.
        let mut out = Vec::with_capacity(num_aggregators as usize);
        out.push(Prio3InputShare {
            input_share: Share::Leader(leader_input_share),
            proof_share: Share::Leader(leader_proof_share),
            joint_rand_param: leader_joint_rand_param,
        });

        for helper in helper_shares.into_iter() {
            // Helpers always carry a joint randomness parameter internally; drop it when the
            // FLP type does not use joint randomness.
            let helper_joint_rand_param = if self.typ.joint_rand_len() > 0 {
                Some(helper.joint_rand_param)
            } else {
                None
            };

            out.push(Prio3InputShare {
                input_share: Share::Helper(helper.input_share),
                proof_share: Share::Helper(helper.proof_share),
                joint_rand_param: helper_joint_rand_param,
            });
        }

        Ok(out)
    }

    /// Shard `measurement` using a constant, non-random source in which every byte is `1`.
    ///
    /// This method is not secure. It exists only to run the Prio3 test vectors.
    #[cfg(feature = "test-util")]
    pub fn test_vec_shard(
        &self,
        measurement: &T::Measurement,
    ) -> Result<Vec<Prio3InputShare<T::Field, L>>, VdafError> {
        // Deterministic stand-in for a random source: fill the whole buffer with ones.
        let constant_source: RandSource = |buf| {
            buf.fill(1);
            Ok(())
        };
        self.shard_with_rand_source(measurement, constant_source)
    }

    /// Convert `agg_id` into a one-byte aggregator ID.
    ///
    /// # Errors
    ///
    /// Fails if `agg_id` is not smaller than the number of aggregators.
    fn role_try_from(&self, agg_id: usize) -> Result<u8, VdafError> {
        match u8::try_from(agg_id) {
            Ok(role) if role < self.num_aggregators => Ok(role),
            // Either the ID does not fit in a byte or it is out of range for this instance.
            _ => Err(VdafError::Uncategorized("unexpected aggregator id".into())),
        }
    }
}

impl<T, P, const L: usize> Vdaf for Prio3<T, P, L>
where
    T: Type,
    P: Prg<L>,
{
    // The algorithm ID, measurement type and aggregate result are taken from the FLP type.
    const ID: u32 = T::ID;
    type Measurement = T::Measurement;
    type AggregateResult = T::AggregateResult;
    // Prio3 has no aggregation parameter and no public share.
    type AggregationParam = ();
    type PublicShare = ();
    type InputShare = Prio3InputShare<T::Field, L>;
    type OutputShare = OutputShare<T::Field>;
    type AggregateShare = AggregateShare<T::Field>;

    /// Number of aggregators this instance was configured with.
    fn num_aggregators(&self) -> usize {
        usize::from(self.num_aggregators)
    }
}

/// Message sent by the [`Client`](crate::vdaf::Client) to each
/// [`Aggregator`](crate::vdaf::Aggregator) during the Sharding phase.
///
/// Sharding produces `Share::Leader` values for the leader and `Share::Helper` values for the
/// helpers. The input and proof shares are expected to be of the same kind: encoding panics if
/// one is a leader share and the other a helper share.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Prio3InputShare<F, const L: usize> {
    /// The input share.
    input_share: Share<F, L>,

    /// The proof share.
    proof_share: Share<F, L>,

    /// Parameters used by the Aggregator to compute the joint randomness. This field is optional
    /// because not every [`Type`](crate::flp::Type) requires joint randomness.
    joint_rand_param: Option<JointRandParam<L>>,
}

impl<F: FieldElement, const L: usize> Encode for Prio3InputShare<F, L> {
    /// Append the input share, the proof share and, if present, the joint randomness parameter
    /// (blind followed by each seed hint) to `bytes`.
    ///
    /// # Panics
    ///
    /// Panics if one of the two shares is a leader share and the other is a helper share.
    fn encode(&self, bytes: &mut Vec<u8>) {
        match (&self.input_share, &self.proof_share) {
            (Share::Leader(_), Share::Helper(_)) | (Share::Helper(_), Share::Leader(_)) => {
                panic!("tried to encode input share with ambiguous encoding")
            }
            _ => {}
        }

        self.input_share.encode(bytes);
        self.proof_share.encode(bytes);

        if let Some(param) = self.joint_rand_param.as_ref() {
            param.blind.encode(bytes);
            for hint in &param.seed_hint {
                hint.encode(bytes);
            }
        }
    }
}

impl<'a, T, P, const L: usize> ParameterizedDecode<(&'a Prio3<T, P, L>, usize)>
    for Prio3InputShare<T::Field, L>
where
    T: Type,
    P: Prg<L>,
{
    /// Decode the input share addressed to aggregator `agg_id` of the given `Prio3` instance.
    ///
    /// Aggregator 0 (the leader) is decoded with explicit input and proof lengths taken from
    /// the FLP type; every other aggregator is decoded as a helper. The joint randomness
    /// parameter (a blind followed by `num_aggregators - 1` seed hints) is read only when the
    /// FLP type has a non-zero joint randomness length.
    fn decode_with_param(
        (prio3, agg_id): &(&'a Prio3<T, P, L>, usize),
        bytes: &mut Cursor<&[u8]>,
    ) -> Result<Self, CodecError> {
        let role = prio3
            .role_try_from(*agg_id)
            .map_err(|err| CodecError::Other(Box::new(err)))?;

        let (input_decoder, proof_decoder) = match role {
            0 => (
                ShareDecodingParameter::Leader(prio3.typ.input_len()),
                ShareDecodingParameter::Leader(prio3.typ.proof_len()),
            ),
            _ => (
                ShareDecodingParameter::Helper,
                ShareDecodingParameter::Helper,
            ),
        };

        let input_share = Share::decode_with_param(&input_decoder, bytes)?;
        let proof_share = Share::decode_with_param(&proof_decoder, bytes)?;

        let joint_rand_param = if prio3.typ.joint_rand_len() == 0 {
            None
        } else {
            let blind = Seed::decode(bytes)?;
            // One hint for every aggregator other than the recipient.
            let hint_count = prio3.num_aggregators() - 1;
            let mut seed_hint = Vec::with_capacity(hint_count);
            for _ in 0..hint_count {
                seed_hint.push(Seed::decode(bytes)?);
            }
            Some(JointRandParam { blind, seed_hint })
        };

        Ok(Prio3InputShare {
            input_share,
            proof_share,
            joint_rand_param,
        })
    }
}

#[derive(Clone, Debug, Eq, PartialEq)]
/// Message broadcast by each [`Aggregator`](crate::vdaf::Aggregator) in each round of the
/// Preparation phase.
pub struct Prio3PrepareShare<F, const L: usize> {
    /// A share of the FLP verifier message. (See [`Type`](crate::flp::Type).)
    verifier: Vec<F>,

    /// A part of the joint randomness seed. `prepare_init` sets this to `Some` only when the FLP
    /// type has a non-zero joint randomness length.
    joint_rand_part: Option<Seed<L>>,
}

impl<F: FieldElement, const L: usize> Encode for Prio3PrepareShare<F, L> {
    /// Append every verifier element, followed by the joint randomness seed part if there is
    /// one, to `bytes`. No length prefix is written.
    fn encode(&self, bytes: &mut Vec<u8>) {
        for elem in self.verifier.iter() {
            elem.encode(bytes);
        }

        if let Some(part) = self.joint_rand_part.as_ref() {
            part.encode(bytes);
        }
    }
}

impl<F: FieldElement, const L: usize> ParameterizedDecode<Prio3PrepareState<F, L>>
    for Prio3PrepareShare<F, L>
{
    /// Decode a prepare share, using the prepare state to determine the layout: it supplies the
    /// number of verifier elements and whether a joint randomness seed part follows them.
    fn decode_with_param(
        decoding_parameter: &Prio3PrepareState<F, L>,
        bytes: &mut Cursor<&[u8]>,
    ) -> Result<Self, CodecError> {
        let verifier = (0..decoding_parameter.verifier_len)
            .map(|_| F::decode(bytes))
            .collect::<Result<Vec<_>, _>>()?;

        let joint_rand_part = match decoding_parameter.joint_rand_seed {
            Some(_) => Some(Seed::decode(bytes)?),
            None => None,
        };

        Ok(Prio3PrepareShare {
            verifier,
            joint_rand_part,
        })
    }
}

#[derive(Clone, Debug, Eq, PartialEq)]
/// Result of combining a round of [`Prio3PrepareShare`] messages.
pub struct Prio3PrepareMessage<const L: usize> {
    /// The joint randomness seed computed by the Aggregators. When this is `None`, nothing is
    /// written by the encoder.
    joint_rand_seed: Option<Seed<L>>,
}

impl<const L: usize> Encode for Prio3PrepareMessage<L> {
    /// Append the joint randomness seed to `bytes`; writes nothing when there is no seed.
    fn encode(&self, bytes: &mut Vec<u8>) {
        match self.joint_rand_seed.as_ref() {
            Some(seed) => seed.encode(bytes),
            None => {}
        }
    }
}

impl<F: FieldElement, const L: usize> ParameterizedDecode<Prio3PrepareState<F, L>>
    for Prio3PrepareMessage<L>
{
    /// Decode a prepare message. A seed is read only if the prepare state itself holds a joint
    /// randomness seed; otherwise no bytes are consumed.
    fn decode_with_param(
        decoding_parameter: &Prio3PrepareState<F, L>,
        bytes: &mut Cursor<&[u8]>,
    ) -> Result<Self, CodecError> {
        let joint_rand_seed = match decoding_parameter.joint_rand_seed {
            Some(_) => Some(Seed::decode(bytes)?),
            None => None,
        };

        Ok(Prio3PrepareMessage { joint_rand_seed })
    }
}

impl<T, P, const L: usize> Client for Prio3<T, P, L>
where
    T: Type,
    P: Prg<L>,
{
    /// Shard `measurement` into input shares, drawing randomness from `getrandom::getrandom`.
    /// The public share is always `()`.
    #[allow(clippy::type_complexity)]
    fn shard(
        &self,
        measurement: &T::Measurement,
    ) -> Result<((), Vec<Prio3InputShare<T::Field, L>>), VdafError> {
        let input_shares = self.shard_with_rand_source(measurement, getrandom::getrandom)?;
        Ok(((), input_shares))
    }
}

/// State of each [`Aggregator`](crate::vdaf::Aggregator) during the Preparation phase.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Prio3PrepareState<F, const L: usize> {
    /// The aggregator's input share, copied from the input share message by `prepare_init`.
    input_share: Share<F, L>,
    /// The joint randomness seed this aggregator computed; `None` when the FLP type uses no
    /// joint randomness.
    joint_rand_seed: Option<Seed<L>>,
    /// ID of the aggregator holding this state. Not part of the encoding; it is supplied again
    /// as a decoding parameter.
    agg_id: u8,
    /// Number of field elements in a verifier share. Not part of the encoding; the decoder
    /// takes it from the FLP type.
    verifier_len: usize,
}

impl<F: FieldElement, const L: usize> Encode for Prio3PrepareState<F, L> {
    /// Append the encoded form of this object to the end of `bytes`, growing the vector as needed.
    ///
    /// Only the input share and, if present, the joint randomness seed are written; `agg_id` and
    /// `verifier_len` are not encoded.
    fn encode(&self, bytes: &mut Vec<u8>) {
        self.input_share.encode(bytes);

        match self.joint_rand_seed.as_ref() {
            Some(seed) => seed.encode(bytes),
            None => {}
        }
    }
}

impl<'a, T, P, const L: usize> ParameterizedDecode<(&'a Prio3<T, P, L>, usize)>
    for Prio3PrepareState<T::Field, L>
where
    T: Type,
    P: Prg<L>,
{
    /// Decode the prepare state of aggregator `agg_id` of the given `Prio3` instance.
    ///
    /// The input share is read as a leader share for aggregator 0 and as a helper share
    /// otherwise. A joint randomness seed follows only when the FLP type has a non-zero joint
    /// randomness length. The aggregator ID and verifier length are not read from `bytes`; they
    /// come from the decoding parameter.
    fn decode_with_param(
        (prio3, agg_id): &(&'a Prio3<T, P, L>, usize),
        bytes: &mut Cursor<&[u8]>,
    ) -> Result<Self, CodecError> {
        let role = prio3
            .role_try_from(*agg_id)
            .map_err(|err| CodecError::Other(Box::new(err)))?;

        let share_decoder = match role {
            0 => ShareDecodingParameter::Leader(prio3.typ.input_len()),
            _ => ShareDecodingParameter::Helper,
        };
        let input_share = Share::decode_with_param(&share_decoder, bytes)?;

        let joint_rand_seed = if prio3.typ.joint_rand_len() == 0 {
            None
        } else {
            Some(Seed::decode(bytes)?)
        };

        let verifier_len = prio3.typ.verifier_len();
        Ok(Self {
            input_share,
            joint_rand_seed,
            agg_id: role,
            verifier_len,
        })
    }
}

impl<T, P, const L: usize> Aggregator<L> for Prio3<T, P, L>
where
    T: Type,
    P: Prg<L>,
{
    type PrepareState = Prio3PrepareState<T::Field, L>;
    type PrepareShare = Prio3PrepareShare<T::Field, L>;
    type PrepareMessage = Prio3PrepareMessage<L>;

    /// Begins the Prep process with the other aggregators. The result of this process is
    /// the aggregator's output share.
    #[allow(clippy::type_complexity)]
    fn prepare_init(
        &self,
        verify_key: &[u8; L],
        agg_id: usize,
        _agg_param: &(),
        nonce: &[u8],
        _public_share: &(),
        msg: &Prio3InputShare<T::Field, L>,
    ) -> Result<
        (
            Prio3PrepareState<T::Field, L>,
            Prio3PrepareShare<T::Field, L>,
        ),
        VdafError,
    > {
        let agg_id = self.role_try_from(agg_id)?;
        let mut info = [0; DST_LEN + 1];
        info[..VERSION.len()].copy_from_slice(VERSION);
        info[VERSION.len()..DST_LEN].copy_from_slice(&Self::ID.to_be_bytes());
        info[DST_LEN] = agg_id;
        let domain_separation_tag = &info[..DST_LEN];

        let mut deriver = P::init(verify_key);
        deriver.update(domain_separation_tag);
        deriver.update(&[255]);
        deriver.update(nonce);
        let query_rand_prng = Prng::from_seed_stream(deriver.into_seed_stream());

        // Create a reference to the (expanded) input share.
        let expanded_input_share: Option<Vec<T::Field>> = match msg.input_share {
            Share::Leader(_) => None,
            Share::Helper(ref seed) => {
                let prng = Prng::from_seed_stream(P::seed_stream(seed, &info));
                Some(prng.take(self.typ.input_len()).collect())
            }
        };
        let input_share = match msg.input_share {
            Share::Leader(ref data) => data,
            Share::Helper(_) => expanded_input_share.as_ref().unwrap(),
        };

        // Create a reference to the (expanded) proof share.
        let expanded_proof_share: Option<Vec<T::Field>> = match msg.proof_share {
            Share::Leader(_) => None,
            Share::Helper(ref seed) => {
                let prng = Prng::from_seed_stream(P::seed_stream(seed, &info));
                Some(prng.take(self.typ.proof_len()).collect())
            }
        };
        let proof_share = match msg.proof_share {
            Share::Leader(ref data) => data,
            Share::Helper(_) => expanded_proof_share.as_ref().unwrap(),
        };

        // Compute the joint randomness.
        let (joint_rand_seed, joint_rand_seed_part, joint_rand) = if self.typ.joint_rand_len() > 0 {
            let mut deriver = P::init(msg.joint_rand_param.as_ref().unwrap().blind.as_ref());
            deriver.update(&info);
            for x in input_share {
                deriver.update(&(*x).into());
            }
            let joint_rand_seed_part = deriver.into_seed();

            let hints = &msg.joint_rand_param.as_ref().unwrap().seed_hint;
            let joint_rand_seed = Self::derive_joint_randomness(
                hints[..agg_id as usize]
                    .iter()
                    .chain(std::iter::once(&joint_rand_seed_part))
                    .chain(hints[agg_id as usize..].iter()),
            );

            let prng: Prng<T::Field, _> =
                Prng::from_seed_stream(P::seed_stream(&joint_rand_seed, domain_separation_tag));
            (
                Some(joint_rand_seed),
                Some(joint_rand_seed_part),
                prng.take(self.typ.joint_rand_len()).collect(),
            )
        } else {
            (None, None, Vec::new())
        };

        // Compute the query randomness.
        let query_rand: Vec<T::Field> = query_rand_prng.take(self.typ.query_rand_len()).collect();

        // Run the query-generation algorithm.
        let verifier_share = self.typ.query(
            input_share,
            proof_share,
            &query_rand,
            &joint_rand,
            self.num_aggregators as usize,
        )?;

        Ok((
            Prio3PrepareState {
                input_share: msg.input_share.clone(),
                joint_rand_seed,
                agg_id,
                verifier_len: verifier_share.len(),
            },
            Prio3PrepareShare {
                verifier: verifier_share,
                joint_rand_part: joint_rand_seed_part,
            },
        ))
    }

    /// Combines the Aggregators' prepare shares into the prepare message.
    ///
    /// The verifier shares are summed element-wise and the sum is handed to the FLP decision
    /// procedure. When the type uses joint randomness, the joint randomness seed is derived from
    /// the seed parts carried by the shares, in the order the shares are supplied, and returned
    /// in the message.
    ///
    /// # Errors
    ///
    /// Returns [`VdafError::Uncategorized`] if a verifier share has an unexpected length, if a
    /// share lacks a joint randomness seed part although the type requires one, if the number of
    /// shares differs from the number of Aggregators, or if the proof verifier check fails. An
    /// error raised by the decision procedure itself is converted into a [`VdafError`].
    fn prepare_preprocess<M: IntoIterator<Item = Prio3PrepareShare<T::Field, L>>>(
        &self,
        inputs: M,
    ) -> Result<Prio3PrepareMessage<L>, VdafError> {
        let mut verifier = vec![T::Field::zero(); self.typ.verifier_len()];
        let mut joint_rand_parts = Vec::with_capacity(self.num_aggregators());

        // Count in `usize` rather than in the width of the `num_aggregators` field: a narrow
        // counter would overflow (panic in debug builds, wrap in release builds) when handed
        // more shares than it can represent, and a wrapped count could pass the check below.
        let mut count: usize = 0;
        for share in inputs {
            count += 1;

            if share.verifier.len() != verifier.len() {
                return Err(VdafError::Uncategorized(format!(
                    "unexpected verifier share length: got {}; want {}",
                    share.verifier.len(),
                    verifier.len(),
                )));
            }

            if self.typ.joint_rand_len() > 0 {
                // The share comes from another Aggregator, so report a missing seed part as an
                // error instead of panicking.
                let joint_rand_seed_part = share.joint_rand_part.ok_or_else(|| {
                    VdafError::Uncategorized(
                        "prepare share is missing its joint randomness seed part".to_string(),
                    )
                })?;
                joint_rand_parts.push(joint_rand_seed_part);
            }

            for (x, y) in verifier.iter_mut().zip(share.verifier) {
                *x += y;
            }
        }

        if count != self.num_aggregators() {
            return Err(VdafError::Uncategorized(format!(
                "unexpected message count: got {}; want {}",
                count, self.num_aggregators,
            )));
        }

        // Check the proof verifier.
        if !self.typ.decide(&verifier)? {
            return Err(VdafError::Uncategorized(
                "proof verifier check failed".into(),
            ));
        }

        let joint_rand_seed = (self.typ.joint_rand_len() > 0)
            .then(|| Self::derive_joint_randomness(joint_rand_parts.iter()));

        Ok(Prio3PrepareMessage { joint_rand_seed })
    }

    fn prepare_step(
        &self,
        step: Prio3PrepareState<T::Field, L>,
        msg: Prio3PrepareMessage<L>,
    ) -> Result<PrepareTransition<Self, L>, VdafError> {
        if self.typ.joint_rand_len() > 0 {
            // Check that the joint randomness was correct.
            if step.joint_rand_seed.as_ref().unwrap() != msg.joint_rand_seed.as_ref().unwrap() {
                return Err(VdafError::Uncategorized(
                    "joint randomness mismatch".to_string(),
                ));
            }
        }

        // Compute the output share.
        let input_share = match step.input_share {
            Share::Leader(data) => data,
            Share::Helper(seed) => {
                let mut info = [0; DST_LEN + 1];
                info[..VERSION.len()].copy_from_slice(VERSION);
                info[VERSION.len()..DST_LEN].copy_from_slice(&Self::ID.to_be_bytes());
                info[DST_LEN] = step.agg_id;
                let prng = Prng::from_seed_stream(P::seed_stream(&seed, &info));
                prng.take(self.typ.input_len()).collect()
            }
        };

        let output_share = match self.typ.truncate(input_share) {
            Ok(data) => OutputShare(data),
            Err(err) => {
                return Err(VdafError::from(err));
            }
        };

        Ok(PrepareTransition::Finish(output_share))
    }

    /// Accumulates the given output shares into a single aggregate share.
    ///
    /// The aggregate starts as an all-zero vector of `output_len` field elements. The first
    /// error reported while accumulating a share is returned to the caller.
    fn aggregate<It: IntoIterator<Item = OutputShare<T::Field>>>(
        &self,
        _agg_param: &(),
        output_shares: It,
    ) -> Result<AggregateShare<T::Field>, VdafError> {
        let zeros = vec![T::Field::zero(); self.typ.output_len()];
        let mut total = AggregateShare(zeros);
        for share in output_shares {
            total.accumulate(&share)?;
        }

        Ok(total)
    }
}

impl<T, P, const L: usize> Collector for Prio3<T, P, L>
where
    T: Type,
    P: Prg<L>,
{
    /// Merges the aggregate shares and decodes the merged vector into the aggregate result.
    ///
    /// The merge starts from an all-zero vector of `output_len` field elements. An error from
    /// merging a share or from decoding the result is returned to the caller.
    fn unshard<It: IntoIterator<Item = AggregateShare<T::Field>>>(
        &self,
        _agg_param: &(),
        agg_shares: It,
        num_measurements: usize,
    ) -> Result<T::AggregateResult, VdafError> {
        let zeros = vec![T::Field::zero(); self.typ.output_len()];
        let mut combined = AggregateShare(zeros);
        for share in agg_shares {
            combined.merge(&share)?;
        }

        let result = self.typ.decode_result(&combined.0, num_measurements)?;
        Ok(result)
    }
}

/// Parameters an Aggregator uses to compute the joint randomness.
///
/// During preparation, `blind` keys the derivation of the Aggregator's own joint randomness seed
/// part, and that part is combined with the parts in `seed_hint` to derive the joint randomness
/// seed.
#[derive(Clone, Debug, Eq, PartialEq)]
struct JointRandParam<const L: usize> {
    /// The joint randomness seed parts corresponding to the other Aggregators' shares.
    seed_hint: Vec<Seed<L>>,

    /// The blinding factor, used to derive the aggregator's joint randomness seed part.
    blind: Seed<L>,
}

/// Seed material bundled together for one helper's share.
// NOTE(review): presumably `input_share` and `proof_share` are the seeds a helper later expands
// into its input share and proof share (compare `Share::Helper`) -- confirm against the sharding
// code, which is outside this section.
#[derive(Clone)]
struct HelperShare<const L: usize> {
    /// Seed associated with the input share.
    input_share: Seed<L>,
    /// Seed associated with the proof share.
    proof_share: Seed<L>,
    /// Joint randomness parameters (seed hints and blind) for this share.
    joint_rand_param: JointRandParam<L>,
}

impl<const L: usize> HelperShare<L> {
    /// Builds a helper share whose seeds are all drawn from `rand_source`.
    ///
    /// The seeds are drawn in the order input share, proof share, blind. The list of seed hints
    /// starts out empty. Any error reported while drawing a seed is returned to the caller.
    fn from_rand_source(rand_source: RandSource) -> Result<Self, VdafError> {
        let input_share = Seed::from_rand_source(rand_source)?;
        let proof_share = Seed::from_rand_source(rand_source)?;
        let blind = Seed::from_rand_source(rand_source)?;

        Ok(Self {
            input_share,
            proof_share,
            joint_rand_param: JointRandParam {
                seed_hint: Vec::new(),
                blind,
            },
        })
    }
}

/// Validates the requested number of aggregators.
///
/// Values from 1 through 254 are accepted. Zero and values above 254 are rejected with a
/// [`VdafError::Uncategorized`] describing the violated bound.
fn check_num_aggregators(num_aggregators: u8) -> Result<(), VdafError> {
    match num_aggregators {
        0 => Err(VdafError::Uncategorized(format!(
            "at least one aggregator is required; got {}",
            num_aggregators
        ))),
        n if n > 254 => Err(VdafError::Uncategorized(format!(
            "number of aggregators must not exceed 254; got {}",
            n
        ))),
        _ => Ok(()),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::flp::gadgets::ParallelSumGadget;
    use crate::vdaf::{run_vdaf, run_vdaf_prepare};
    use assert_matches::assert_matches;
    use fixed::types::extra::{U15, U31, U63};
    use fixed::{FixedI16, FixedI32, FixedI64};
    use fixed_macro::fixed;
    use rand::prelude::*;

    /// Runs the count type end to end with two and three aggregators, and checks that
    /// preparation succeeds for both valid measurements.
    #[test]
    fn test_prio3_count() {
        let vdaf = Prio3::new_aes128_count(2).unwrap();

        let total = run_vdaf(&vdaf, &(), [1, 0, 0, 1, 1]).unwrap();
        assert_eq!(total, 3);

        let mut verify_key = [0; 16];
        thread_rng().fill(&mut verify_key[..]);
        let nonce = b"This is a good nonce.";

        // Preparation must succeed for each valid measurement.
        for measurement in [0, 1] {
            let (public_share, input_shares) = vdaf.shard(&measurement).unwrap();
            run_vdaf_prepare(&vdaf, &verify_key, &(), nonce, public_share, input_shares).unwrap();
        }

        test_prepare_state_serialization(&vdaf, &1).unwrap();

        // The same measurements must aggregate to the same result with one more aggregator.
        let vdaf_extra_helper = Prio3::new_aes128_count(3).unwrap();
        let total_extra_helper = run_vdaf(&vdaf_extra_helper, &(), [1, 0, 0, 1, 1]).unwrap();
        assert_eq!(total_extra_helper, 3);
    }

    /// Runs the sum type end to end, then checks that preparation fails when any part of the
    /// first (leader) input share is tampered with.
    #[test]
    fn test_prio3_sum() {
        let vdaf = Prio3::new_aes128_sum(3, 16).unwrap();

        let measurements = [0, (1 << 16) - 1, 0, 1, 1];
        assert_eq!(
            run_vdaf(&vdaf, &(), measurements).unwrap(),
            (1 << 16) + 1
        );

        let mut verify_key = [0; 16];
        thread_rng().fill(&mut verify_key[..]);
        let nonce = b"This is a good nonce.";

        // Tampered blind.
        {
            let (public_share, mut shares) = vdaf.shard(&1).unwrap();
            shares[0].joint_rand_param.as_mut().unwrap().blind.0[0] ^= 255;
            let outcome = run_vdaf_prepare(&vdaf, &verify_key, &(), nonce, public_share, shares);
            assert_matches!(outcome, Err(VdafError::Uncategorized(_)));
        }

        // Tampered joint randomness seed hint.
        {
            let (public_share, mut shares) = vdaf.shard(&1).unwrap();
            shares[0].joint_rand_param.as_mut().unwrap().seed_hint[0].0[0] ^= 255;
            let outcome = run_vdaf_prepare(&vdaf, &verify_key, &(), nonce, public_share, shares);
            assert_matches!(outcome, Err(VdafError::Uncategorized(_)));
        }

        // Tampered input share.
        {
            let (public_share, mut shares) = vdaf.shard(&1).unwrap();
            assert_matches!(shares[0].input_share, Share::Leader(ref mut data) => {
                data[0] += Field128::one();
            });
            let outcome = run_vdaf_prepare(&vdaf, &verify_key, &(), nonce, public_share, shares);
            assert_matches!(outcome, Err(VdafError::Uncategorized(_)));
        }

        // Tampered proof share.
        {
            let (public_share, mut shares) = vdaf.shard(&1).unwrap();
            assert_matches!(shares[0].proof_share, Share::Leader(ref mut data) => {
                data[0] += Field128::one();
            });
            let outcome = run_vdaf_prepare(&vdaf, &verify_key, &(), nonce, public_share, shares);
            assert_matches!(outcome, Err(VdafError::Uncategorized(_)));
        }

        test_prepare_state_serialization(&vdaf, &1).unwrap();
    }

    /// Aggregating a single counter vector must return that same vector.
    #[test]
    fn test_prio3_countvec() {
        let vdaf = Prio3::new_aes128_count_vec(2, 20).unwrap();

        let measurement = vec![
            0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1,
        ];
        let aggregate = run_vdaf(&vdaf, &(), [measurement]).unwrap();

        assert_eq!(
            aggregate,
            vec![0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1,]
        );
    }

    /// Same check as the single-threaded counter-vector test, using the multithreaded variant.
    #[test]
    #[cfg(feature = "multithreaded")]
    fn test_prio3_countvec_multithreaded() {
        let vdaf = Prio3::new_aes128_count_vec_multithreaded(2, 20).unwrap();

        let measurement = vec![
            0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1,
        ];
        let aggregate = run_vdaf(&vdaf, &(), [measurement]).unwrap();

        assert_eq!(
            aggregate,
            vec![0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1,]
        );
    }

    /// Exercises the fixed-point bounded-L2 vector sum with 16-, 32- and 64-bit fixed-point
    /// entries. Each width is run through the shared `test_fixed` helper defined at the end of
    /// this function; the multithreaded variants are run as well when the `multithreaded`
    /// feature is enabled.
    #[test]
    fn test_prio3_bounded_fpvec_sum() {
        // Constructors for the single-threaded instances, one per fixed-point width.
        type P<Fx> = Prio3Aes128FixedPointBoundedL2VecSum<Fx>;
        let ctor_16 = P::<FixedI16<U15>>::new_aes128_fixedpoint_boundedl2_vec_sum;
        let ctor_32 = P::<FixedI32<U31>>::new_aes128_fixedpoint_boundedl2_vec_sum;
        let ctor_64 = P::<FixedI64<U63>>::new_aes128_fixedpoint_boundedl2_vec_sum;

        // Constructors for the multithreaded instances; only compiled with the feature enabled.
        #[cfg(feature = "multithreaded")]
        type PM<Fx> = Prio3Aes128FixedPointBoundedL2VecSumMultithreaded<Fx>;
        #[cfg(feature = "multithreaded")]
        let ctor_mt_16 = PM::<FixedI16<U15>>::new_aes128_fixedpoint_boundedl2_vec_sum_multithreaded;
        #[cfg(feature = "multithreaded")]
        let ctor_mt_32 = PM::<FixedI32<U31>>::new_aes128_fixedpoint_boundedl2_vec_sum_multithreaded;
        #[cfg(feature = "multithreaded")]
        let ctor_mt_64 = PM::<FixedI64<U63>>::new_aes128_fixedpoint_boundedl2_vec_sum_multithreaded;

        {
            // 16 bit fixedpoint
            let fp16_4_inv = fixed!(0.25: I1F15);
            let fp16_8_inv = fixed!(0.125: I1F15);
            let fp16_16_inv = fixed!(0.0625: I1F15);

            // two aggregators, three entries per vector.
            {
                let prio3_16 = ctor_16(2, 3).unwrap();
                test_fixed(fp16_4_inv, fp16_8_inv, fp16_16_inv, prio3_16);
            }

            #[cfg(feature = "multithreaded")]
            {
                let prio3_16_mt = ctor_mt_16(2, 3).unwrap();
                test_fixed(fp16_4_inv, fp16_8_inv, fp16_16_inv, prio3_16_mt);
            }
        }

        {
            // 32 bit fixedpoint
            let fp32_4_inv = fixed!(0.25: I1F31);
            let fp32_8_inv = fixed!(0.125: I1F31);
            let fp32_16_inv = fixed!(0.0625: I1F31);

            {
                let prio3_32 = ctor_32(2, 3).unwrap();
                test_fixed(fp32_4_inv, fp32_8_inv, fp32_16_inv, prio3_32);
            }

            #[cfg(feature = "multithreaded")]
            {
                let prio3_32_mt = ctor_mt_32(2, 3).unwrap();
                test_fixed(fp32_4_inv, fp32_8_inv, fp32_16_inv, prio3_32_mt);
            }
        }

        {
            // 64 bit fixedpoint
            let fp64_4_inv = fixed!(0.25: I1F63);
            let fp64_8_inv = fixed!(0.125: I1F63);
            let fp64_16_inv = fixed!(0.0625: I1F63);

            {
                let prio3_64 = ctor_64(2, 3).unwrap();
                test_fixed(fp64_4_inv, fp64_8_inv, fp64_16_inv, prio3_64);
            }

            #[cfg(feature = "multithreaded")]
            {
                let prio3_64_mt = ctor_mt_64(2, 3).unwrap();
                test_fixed(fp64_4_inv, fp64_8_inv, fp64_16_inv, prio3_64_mt);
            }
        }

        /// Shared test body, generic over the fixed-point type and the parallel-sum gadgets.
        ///
        /// `fp_4_inv`, `fp_8_inv` and `fp_16_inv` are passed by the callers above as 0.25, 0.125
        /// and 0.0625 in the fixed-point type under test. The helper checks the aggregate for
        /// positive, negative and mixed inputs, then checks that preparation fails when the
        /// first input share is tampered with.
        fn test_fixed<Fx, PE, BPE>(
            fp_4_inv: Fx,
            fp_8_inv: Fx,
            fp_16_inv: Fx,
            prio3: Prio3<FixedPointBoundedL2VecSum<Fx, Field128, PE, BPE>, PrgAes128, 16>,
        ) where
            Fx: Fixed + CompatibleFloat<Field128> + std::ops::Neg<Output = Fx>,
            PE: Eq + ParallelSumGadget<Field128, PolyEval<Field128>> + Clone + 'static,
            BPE: Eq + ParallelSumGadget<Field128, BlindPolyEval<Field128>> + Clone + 'static,
        {
            let fp_vec1 = vec![fp_4_inv, fp_8_inv, fp_16_inv];
            let fp_vec2 = vec![fp_4_inv, fp_8_inv, fp_16_inv];

            let fp_vec3 = vec![-fp_4_inv, -fp_8_inv, -fp_16_inv];
            let fp_vec4 = vec![-fp_4_inv, -fp_8_inv, -fp_16_inv];

            // The second and third entries of these two vectors cancel each other out.
            let fp_vec5 = vec![fp_4_inv, -fp_8_inv, -fp_16_inv];
            let fp_vec6 = vec![fp_4_inv, fp_8_inv, fp_16_inv];

            // positive entries
            let fp_list = [fp_vec1, fp_vec2];
            assert_eq!(
                run_vdaf(&prio3, &(), fp_list).unwrap(),
                vec!(0.5, 0.25, 0.125),
            );

            // negative entries
            let fp_list2 = [fp_vec3, fp_vec4];
            assert_eq!(
                run_vdaf(&prio3, &(), fp_list2).unwrap(),
                vec!(-0.5, -0.25, -0.125),
            );

            // both
            let fp_list3 = [fp_vec5, fp_vec6];
            assert_eq!(
                run_vdaf(&prio3, &(), fp_list3).unwrap(),
                vec!(0.5, 0.0, 0.0),
            );

            let mut verify_key = [0; 16];
            thread_rng().fill(&mut verify_key[..]);
            let nonce = b"This is a good nonce.";

            // Tampered blind: preparation must fail.
            let (public_share, mut input_shares) =
                prio3.shard(&vec![fp_4_inv, fp_8_inv, fp_16_inv]).unwrap();
            input_shares[0].joint_rand_param.as_mut().unwrap().blind.0[0] ^= 255;
            let result =
                run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares);
            assert_matches!(result, Err(VdafError::Uncategorized(_)));

            // Tampered joint randomness seed hint: preparation must fail.
            let (public_share, mut input_shares) =
                prio3.shard(&vec![fp_4_inv, fp_8_inv, fp_16_inv]).unwrap();
            input_shares[0].joint_rand_param.as_mut().unwrap().seed_hint[0].0[0] ^= 255;
            let result =
                run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares);
            assert_matches!(result, Err(VdafError::Uncategorized(_)));

            // Tampered input share: preparation must fail.
            let (public_share, mut input_shares) =
                prio3.shard(&vec![fp_4_inv, fp_8_inv, fp_16_inv]).unwrap();
            assert_matches!(input_shares[0].input_share, Share::Leader(ref mut data) => {
                data[0] += Field128::one();
            });
            let result =
                run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares);
            assert_matches!(result, Err(VdafError::Uncategorized(_)));

            // Tampered proof share: preparation must fail.
            let (public_share, mut input_shares) =
                prio3.shard(&vec![fp_4_inv, fp_8_inv, fp_16_inv]).unwrap();
            assert_matches!(input_shares[0].proof_share, Share::Leader(ref mut data) => {
                    data[0] += Field128::one();
            });
            let result =
                run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares);
            assert_matches!(result, Err(VdafError::Uncategorized(_)));

            test_prepare_state_serialization(&prio3, &vec![fp_4_inv, fp_8_inv, fp_16_inv]).unwrap();
        }
    }

    /// Checks bucket assignment for a histogram with boundaries 0, 10 and 20, including values
    /// that fall exactly on a boundary and values beyond the last boundary.
    #[test]
    fn test_prio3_histogram() {
        let vdaf = Prio3::new_aes128_histogram(2, &[0, 10, 20]).unwrap();

        // One measurement per bucket.
        let one_per_bucket = run_vdaf(&vdaf, &(), [0, 10, 20, 9999]).unwrap();
        assert_eq!(one_per_bucket, vec![1, 1, 1, 1]);

        // Single measurements paired with the counts they must produce.
        let cases = [
            (0, vec![1, 0, 0, 0]),
            (5, vec![0, 1, 0, 0]),
            (10, vec![0, 1, 0, 0]),
            (15, vec![0, 0, 1, 0]),
            (20, vec![0, 0, 1, 0]),
            (25, vec![0, 0, 0, 1]),
        ];
        for (measurement, expected) in cases {
            assert_eq!(run_vdaf(&vdaf, &(), [measurement]).unwrap(), expected);
        }

        test_prepare_state_serialization(&vdaf, &23).unwrap();
    }

    /// Checks the computed average for several batches of measurements.
    #[test]
    fn test_prio3_average() {
        let vdaf = Prio3::new_aes128_average(2, 64).unwrap();

        // Each batch of measurements is paired with the average it must produce.
        let cases = [
            (vec![17, 8], 12.5f64),
            (vec![1, 1, 1, 1], 1f64),
            (vec![0, 0, 0, 1], 0.25f64),
            (vec![1, 11, 111, 1111, 3, 8], 207.5f64),
        ];
        for (measurements, expected) in cases {
            assert_eq!(run_vdaf(&vdaf, &(), measurements).unwrap(), expected);
        }
    }

    /// Checks that no two input shares produced by sharding carry the same seeds or the same
    /// joint randomness parameters.
    #[test]
    fn test_prio3_input_share() {
        let vdaf = Prio3::new_aes128_sum(5, 16).unwrap();
        let (_public_share, shares) = vdaf.shard(&1).unwrap();

        // Compare every ordered pair of distinct shares.
        for (i, x) in shares.iter().enumerate() {
            for (j, y) in shares.iter().enumerate() {
                if i == j {
                    continue;
                }

                // Seeds are only compared when both shares are helper shares.
                match (&x.input_share, &y.input_share) {
                    (Share::Helper(left), Share::Helper(right)) => assert_ne!(left, right),
                    _ => (),
                }

                match (&x.proof_share, &y.proof_share) {
                    (Share::Helper(left), Share::Helper(right)) => assert_ne!(left, right),
                    _ => (),
                }

                assert_ne!(x.joint_rand_param, y.joint_rand_param);
            }
        }
    }

    /// Shards `measurement`, initializes preparation for every aggregator with a random verify
    /// key, and checks that each resulting prepare state decodes back to itself after encoding.
    fn test_prepare_state_serialization<T, P, const L: usize>(
        prio3: &Prio3<T, P, L>,
        measurement: &T::Measurement,
    ) -> Result<(), VdafError>
    where
        T: Type,
        P: Prg<L>,
    {
        let mut verify_key = [0; L];
        thread_rng().fill(&mut verify_key[..]);

        let (public_share, input_shares) = prio3.shard(measurement)?;
        for (agg_id, input_share) in input_shares.iter().enumerate() {
            let (expected, _prep_share) =
                prio3.prepare_init(&verify_key, agg_id, &(), &[], &public_share, input_share)?;

            // Round-trip the state through its encoding.
            let encoded = expected.get_encoded();
            let decoded = Prio3PrepareState::get_decoded_with_param(&(prio3, agg_id), &encoded)
                .expect("failed to decode prepare step");
            assert_eq!(decoded, expected);
        }

        Ok(())
    }
}