#[cfg(feature = "crypto-dependencies")]
use super::prg::PrgAes128;
use super::{DST_LEN, VERSION};
use crate::codec::{CodecError, Decode, Encode, ParameterizedDecode};
use crate::field::FieldElement;
#[cfg(feature = "crypto-dependencies")]
use crate::field::{Field128, Field64};
#[cfg(feature = "multithreaded")]
use crate::flp::gadgets::ParallelSumMultithreaded;
#[cfg(feature = "crypto-dependencies")]
use crate::flp::gadgets::{BlindPolyEval, ParallelSum, PolyEval};
#[cfg(feature = "crypto-dependencies")]
use crate::flp::types::fixedpoint_l2::compatible_float::CompatibleFloat;
#[cfg(feature = "crypto-dependencies")]
use crate::flp::types::fixedpoint_l2::FixedPointBoundedL2VecSum;
#[cfg(feature = "crypto-dependencies")]
use crate::flp::types::{Average, Count, CountVec, Histogram, Sum};
use crate::flp::Type;
use crate::prng::Prng;
use crate::vdaf::prg::{Prg, RandSource, Seed};
use crate::vdaf::{
Aggregatable, AggregateShare, Aggregator, Client, Collector, OutputShare, PrepareTransition,
Share, ShareDecodingParameter, Vdaf, VdafError,
};
#[cfg(feature = "crypto-dependencies")]
use fixed::traits::Fixed;
use std::convert::TryFrom;
use std::fmt::Debug;
use std::io::Cursor;
use std::iter::IntoIterator;
use std::marker::PhantomData;
/// Prio3 instantiated with the `Count` FLP type over [`Field64`] and the
/// [`PrgAes128`] PRG using 16-byte seeds.
#[cfg(feature = "crypto-dependencies")]
pub type Prio3Aes128Count = Prio3<Count<Field64>, PrgAes128, 16>;
#[cfg(feature = "crypto-dependencies")]
impl Prio3Aes128Count {
    /// Builds a counting instance split among `num_aggregators` aggregators.
    ///
    /// # Errors
    /// Fails when `num_aggregators` is rejected by the aggregator-count validation in
    /// `Prio3::new`.
    pub fn new_aes128_count(num_aggregators: u8) -> Result<Self, VdafError> {
        let count_type = Count::new();
        Self::new(num_aggregators, count_type)
    }
}
/// Prio3 instantiated with the `CountVec` FLP type over [`Field128`], using the
/// single-threaded `ParallelSum` gadget and the [`PrgAes128`] PRG with 16-byte seeds.
#[cfg(feature = "crypto-dependencies")]
pub type Prio3Aes128CountVec =
    Prio3<CountVec<Field128, ParallelSum<Field128, BlindPolyEval<Field128>>>, PrgAes128, 16>;
#[cfg(feature = "crypto-dependencies")]
impl Prio3Aes128CountVec {
    /// Builds an instance for count vectors of length `len`, split among
    /// `num_aggregators` aggregators.
    ///
    /// # Errors
    /// Fails when `num_aggregators` is rejected by `Prio3::new`.
    pub fn new_aes128_count_vec(num_aggregators: u8, len: usize) -> Result<Self, VdafError> {
        let count_vec_type = CountVec::new(len);
        Self::new(num_aggregators, count_vec_type)
    }
}
/// Same as `Prio3Aes128CountVec`, except the gadget is `ParallelSumMultithreaded`.
#[cfg(feature = "multithreaded")]
#[cfg(feature = "crypto-dependencies")]
#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))]
pub type Prio3Aes128CountVecMultithreaded = Prio3<
    CountVec<Field128, ParallelSumMultithreaded<Field128, BlindPolyEval<Field128>>>,
    PrgAes128,
    16,
>;
#[cfg(feature = "multithreaded")]
#[cfg(feature = "crypto-dependencies")]
#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))]
impl Prio3Aes128CountVecMultithreaded {
    /// Builds a multithreaded-gadget instance for count vectors of length `len`,
    /// split among `num_aggregators` aggregators.
    ///
    /// # Errors
    /// Fails when `num_aggregators` is rejected by `Prio3::new`.
    pub fn new_aes128_count_vec_multithreaded(
        num_aggregators: u8,
        len: usize,
    ) -> Result<Self, VdafError> {
        let count_vec_type = CountVec::new(len);
        Self::new(num_aggregators, count_vec_type)
    }
}
/// Prio3 instantiated with the `Sum` FLP type over [`Field128`] and the
/// [`PrgAes128`] PRG with 16-byte seeds.
#[cfg(feature = "crypto-dependencies")]
pub type Prio3Aes128Sum = Prio3<Sum<Field128>, PrgAes128, 16>;
#[cfg(feature = "crypto-dependencies")]
impl Prio3Aes128Sum {
    /// Builds a summing instance for `bits`-bit measurements, split among
    /// `num_aggregators` aggregators.
    ///
    /// # Errors
    /// Fails if `bits` exceeds 64, if `Sum::new` rejects `bits`, or if
    /// `num_aggregators` is rejected by `Prio3::new`.
    pub fn new_aes128_sum(num_aggregators: u8, bits: u32) -> Result<Self, VdafError> {
        if bits <= 64 {
            let sum_type = Sum::new(bits as usize)?;
            Self::new(num_aggregators, sum_type)
        } else {
            Err(VdafError::Uncategorized(format!(
                "bit length ({}) exceeds limit for aggregate type (64)",
                bits
            )))
        }
    }
}
/// Prio3 instantiated with `FixedPointBoundedL2VecSum` over [`Field128`] for
/// fixed-point vectors of type `Fx`, using single-threaded `ParallelSum` gadgets and the
/// [`PrgAes128`] PRG with 16-byte seeds.
#[cfg(feature = "crypto-dependencies")]
pub type Prio3Aes128FixedPointBoundedL2VecSum<Fx> = Prio3<
    FixedPointBoundedL2VecSum<
        Fx,
        Field128,
        ParallelSum<Field128, PolyEval<Field128>>,
        ParallelSum<Field128, BlindPolyEval<Field128>>,
    >,
    PrgAes128,
    16,
>;
#[cfg(feature = "crypto-dependencies")]
impl<Fx: Fixed + CompatibleFloat<Field128>> Prio3Aes128FixedPointBoundedL2VecSum<Fx> {
    /// Builds an instance for vectors with `entries` fixed-point components, split among
    /// `num_aggregators` aggregators.
    ///
    /// # Errors
    /// Fails if `num_aggregators` is invalid or `FixedPointBoundedL2VecSum::new`
    /// rejects `entries`.
    pub fn new_aes128_fixedpoint_boundedl2_vec_sum(
        num_aggregators: u8,
        entries: usize,
    ) -> Result<Self, VdafError> {
        // Validate the aggregator count first, so that error wins over an invalid `entries`.
        check_num_aggregators(num_aggregators)?;
        let vec_sum_type = FixedPointBoundedL2VecSum::new(entries)?;
        Self::new(num_aggregators, vec_sum_type)
    }
}
/// Same as `Prio3Aes128FixedPointBoundedL2VecSum`, except both gadgets are
/// `ParallelSumMultithreaded`.
#[cfg(feature = "multithreaded")]
#[cfg(feature = "crypto-dependencies")]
#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))]
pub type Prio3Aes128FixedPointBoundedL2VecSumMultithreaded<Fx> = Prio3<
    FixedPointBoundedL2VecSum<
        Fx,
        Field128,
        ParallelSumMultithreaded<Field128, PolyEval<Field128>>,
        ParallelSumMultithreaded<Field128, BlindPolyEval<Field128>>,
    >,
    PrgAes128,
    16,
>;
#[cfg(feature = "multithreaded")]
#[cfg(feature = "crypto-dependencies")]
#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))]
impl<Fx: Fixed + CompatibleFloat<Field128>> Prio3Aes128FixedPointBoundedL2VecSumMultithreaded<Fx> {
    /// Builds a multithreaded-gadget instance for vectors with `entries` fixed-point
    /// components, split among `num_aggregators` aggregators.
    ///
    /// # Errors
    /// Fails if `num_aggregators` is invalid or `FixedPointBoundedL2VecSum::new`
    /// rejects `entries`.
    pub fn new_aes128_fixedpoint_boundedl2_vec_sum_multithreaded(
        num_aggregators: u8,
        entries: usize,
    ) -> Result<Self, VdafError> {
        // Validate the aggregator count first, so that error wins over an invalid `entries`.
        check_num_aggregators(num_aggregators)?;
        let vec_sum_type = FixedPointBoundedL2VecSum::new(entries)?;
        Self::new(num_aggregators, vec_sum_type)
    }
}
/// Prio3 instantiated with the `Histogram` FLP type over [`Field128`] and the
/// [`PrgAes128`] PRG with 16-byte seeds.
#[cfg(feature = "crypto-dependencies")]
pub type Prio3Aes128Histogram = Prio3<Histogram<Field128>, PrgAes128, 16>;
#[cfg(feature = "crypto-dependencies")]
impl Prio3Aes128Histogram {
    /// Builds a histogram instance with the given bucket boundaries, split among
    /// `num_aggregators` aggregators.
    ///
    /// # Errors
    /// Fails if `Histogram::new` rejects the boundaries or `num_aggregators` is invalid.
    pub fn new_aes128_histogram(num_aggregators: u8, buckets: &[u64]) -> Result<Self, VdafError> {
        // `Histogram` takes `u128` boundaries; widening each `u64` is lossless.
        let boundaries = buckets.iter().copied().map(u128::from).collect();
        let histogram_type = Histogram::new(boundaries)?;
        Self::new(num_aggregators, histogram_type)
    }
}
/// Prio3 instantiated with the `Average` FLP type over [`Field128`] and the
/// [`PrgAes128`] PRG with 16-byte seeds.
#[cfg(feature = "crypto-dependencies")]
pub type Prio3Aes128Average = Prio3<Average<Field128>, PrgAes128, 16>;
#[cfg(feature = "crypto-dependencies")]
impl Prio3Aes128Average {
    /// Builds an averaging instance for `bits`-bit measurements, split among
    /// `num_aggregators` aggregators.
    ///
    /// # Errors
    /// Fails if `num_aggregators` is invalid (checked first), if `bits` exceeds 64, or
    /// if `Average::new` rejects `bits`.
    pub fn new_aes128_average(num_aggregators: u8, bits: u32) -> Result<Self, VdafError> {
        check_num_aggregators(num_aggregators)?;
        if bits <= 64 {
            let average_type = Average::new(bits as usize)?;
            Self::new(num_aggregators, average_type)
        } else {
            Err(VdafError::Uncategorized(format!(
                "bit length ({}) exceeds limit for aggregate type (64)",
                bits
            )))
        }
    }
}
/// The Prio3 VDAF.
///
/// `T` is the FLP type that encodes measurements and proves/verifies their validity.
/// `P` is the PRG used to expand `L`-byte seeds.
#[derive(Clone, Debug)]
pub struct Prio3<T, P, const L: usize>
where
    T: Type,
    P: Prg<L>,
{
    /// Number of aggregators each measurement is split among; validated by
    /// `check_num_aggregators` in `Prio3::new`.
    num_aggregators: u8,
    /// The FLP type instance that defines encoding, proof, and output lengths.
    typ: T,
    /// Ties the instance to the PRG type `P` without storing a value of it.
    phantom: PhantomData<P>,
}
/// Type-independent Prio3 operations shared by every instantiation.
impl<T, P, const L: usize> Prio3<T, P, L>
where
    T: Type,
    P: Prg<L>,
{
    /// Creates a Prio3 instance that splits measurements among `num_aggregators`
    /// aggregators using the FLP type `typ`.
    ///
    /// # Errors
    /// Returns the error from `check_num_aggregators` if `num_aggregators` is 0 or
    /// greater than 254.
    pub fn new(num_aggregators: u8, typ: T) -> Result<Self, VdafError> {
        check_num_aggregators(num_aggregators)?;
        Ok(Self {
            num_aggregators,
            typ,
            phantom: PhantomData,
        })
    }

    /// Returns the output share length reported by the underlying FLP type.
    pub fn output_len(&self) -> usize {
        self.typ.output_len()
    }

    /// Returns the verifier share length reported by the underlying FLP type.
    pub fn verifier_len(&self) -> usize {
        self.typ.verifier_len()
    }

    /// Combines per-aggregator joint-randomness parts into the joint-randomness seed.
    ///
    /// A PRG keyed with an all-zero key absorbs `VERSION || ID (4 bytes, big-endian) || 255`
    /// and then each part, in iteration order. The result therefore depends on the order
    /// of `parts`; sharding and `prepare_init` supply them in aggregator-ID order.
    fn derive_joint_randomness<'a>(parts: impl Iterator<Item = &'a Seed<L>>) -> Seed<L> {
        let mut info = [0; VERSION.len() + 5];
        info[..VERSION.len()].copy_from_slice(VERSION);
        info[VERSION.len()..VERSION.len() + 4].copy_from_slice(&Self::ID.to_be_bytes());
        // NOTE(review): 255 appears to act as a usage tag. Aggregator IDs never reach 255
        // because `check_num_aggregators` caps the aggregator count at 254.
        info[VERSION.len() + 4] = 255;
        let mut deriver = P::init(&[0; L]);
        deriver.update(&info);
        for part in parts {
            deriver.update(part.as_ref());
        }
        deriver.into_seed()
    }

    /// Splits `measurement` into one input share per aggregator, drawing all randomness
    /// from `rand_source`.
    ///
    /// - The leader (aggregator 0) receives explicit input and proof vectors.
    /// - Each helper receives only PRG seeds.
    /// - The leader's vectors are the encoded measurement and the proof, minus every
    ///   helper's expanded share.
    /// - When the FLP type uses joint randomness, every share also carries a blind. It also
    ///   carries the other aggregators' joint-randomness parts as seed hints.
    ///
    /// # Errors
    /// Propagates failures from measurement encoding, the random source, and proving.
    fn shard_with_rand_source(
        &self,
        measurement: &T::Measurement,
        rand_source: RandSource,
    ) -> Result<Vec<Prio3InputShare<T::Field, L>>, VdafError> {
        // info = VERSION || ID || aggregator ID. The final byte is rewritten per aggregator,
        // and the first DST_LEN bytes serve as the domain separation tag.
        let mut info = [0; DST_LEN + 1];
        info[..VERSION.len()].copy_from_slice(VERSION);
        info[VERSION.len()..DST_LEN].copy_from_slice(&Self::ID.to_be_bytes());
        let num_aggregators = self.num_aggregators;
        let input = self.typ.encode_measurement(measurement)?;
        let mut helper_shares = Vec::with_capacity(num_aggregators as usize - 1);
        // Helper joint-randomness parts are tracked only if the type uses joint randomness.
        let mut helper_joint_rand_parts = if self.typ.joint_rand_len() > 0 {
            Some(Vec::with_capacity(num_aggregators as usize - 1))
        } else {
            None
        };
        // Start from the full encoded input; each helper's expanded share is subtracted.
        let mut leader_input_share = input.clone();
        for agg_id in 1..num_aggregators {
            let helper = HelperShare::from_rand_source(rand_source)?;
            // Helper's joint-randomness part: PRG keyed by its blind, absorbing `info`
            // and then each element of its expanded input share.
            let mut deriver = P::init(helper.joint_rand_param.blind.as_ref());
            info[DST_LEN] = agg_id;
            deriver.update(&info);
            let prng: Prng<T::Field, _> =
                Prng::from_seed_stream(P::seed_stream(&helper.input_share, &info));
            for (x, y) in leader_input_share
                .iter_mut()
                .zip(prng)
                .take(self.typ.input_len())
            {
                *x -= y;
                deriver.update(&y.into());
            }
            if let Some(helper_joint_rand_parts) = helper_joint_rand_parts.as_mut() {
                helper_joint_rand_parts.push(deriver.into_seed());
            }
            helper_shares.push(helper);
        }
        // The leader's part is derived the same way from a fresh blind and its own share.
        let leader_blind = Seed::from_rand_source(rand_source)?;
        // Aggregator ID 0 is the leader.
        info[DST_LEN] = 0;
        let mut deriver = P::init(leader_blind.as_ref());
        deriver.update(&info);
        for x in leader_input_share.iter() {
            deriver.update(&(*x).into());
        }
        let leader_joint_rand_seed_part = deriver.into_seed();
        // Joint-randomness seed over all parts: leader first, then helpers in ID order.
        let joint_rand_seed = helper_joint_rand_parts.as_ref().map(|parts| {
            Self::derive_joint_randomness(
                std::iter::once(&leader_joint_rand_seed_part).chain(parts.iter()),
            )
        });
        let domain_separation_tag = &info[..DST_LEN];
        // Empty when the type does not use joint randomness.
        let joint_rand: Vec<T::Field> = joint_rand_seed
            .map(|joint_rand_seed| {
                let prng: Prng<T::Field, _> =
                    Prng::from_seed_stream(P::seed_stream(&joint_rand_seed, domain_separation_tag));
                prng.take(self.typ.joint_rand_len()).collect()
            })
            .unwrap_or_default();
        // Proving randomness is expanded from a freshly drawn seed.
        let prng: Prng<T::Field, _> = Prng::from_seed_stream(P::seed_stream(
            &Seed::from_rand_source(rand_source)?,
            domain_separation_tag,
        ));
        let prove_rand: Vec<T::Field> = prng.take(self.typ.prove_rand_len()).collect();
        // Start from the full proof; each helper's expanded proof share is subtracted.
        let mut leader_proof_share = self.typ.prove(&input, &prove_rand, &joint_rand)?;
        for (j, helper) in helper_shares.iter_mut().enumerate() {
            // Helper at index j has aggregator ID j + 1.
            info[DST_LEN] = j as u8 + 1;
            let prng: Prng<T::Field, _> =
                Prng::from_seed_stream(P::seed_stream(&helper.proof_share, &info));
            for (x, y) in leader_proof_share
                .iter_mut()
                .zip(prng)
                .take(self.typ.proof_len())
            {
                *x -= y;
            }
            // A helper's hints are every other aggregator's part in aggregator-ID order,
            // with its own part omitted.
            if let Some(helper_joint_rand_parts) = helper_joint_rand_parts.as_ref() {
                let mut hint = Vec::with_capacity(num_aggregators as usize - 1);
                hint.push(leader_joint_rand_seed_part.clone());
                hint.extend(helper_joint_rand_parts[..j].iter().cloned());
                hint.extend(helper_joint_rand_parts[j + 1..].iter().cloned());
                helper.joint_rand_param.seed_hint = hint;
            }
        }
        // The leader's hints are exactly the helpers' parts.
        let leader_joint_rand_param = if self.typ.joint_rand_len() > 0 {
            Some(JointRandParam {
                seed_hint: helper_joint_rand_parts.unwrap_or_default(),
                blind: leader_blind,
            })
        } else {
            None
        };
        let mut out = Vec::with_capacity(num_aggregators as usize);
        out.push(Prio3InputShare {
            input_share: Share::Leader(leader_input_share),
            proof_share: Share::Leader(leader_proof_share),
            joint_rand_param: leader_joint_rand_param,
        });
        for helper in helper_shares.into_iter() {
            // Helpers carry a joint-randomness parameter only if the type uses it.
            let helper_joint_rand_param = if self.typ.joint_rand_len() > 0 {
                Some(helper.joint_rand_param)
            } else {
                None
            };
            out.push(Prio3InputShare {
                input_share: Share::Helper(helper.input_share),
                proof_share: Share::Helper(helper.proof_share),
                joint_rand_param: helper_joint_rand_param,
            });
        }
        Ok(out)
    }

    /// Shards `measurement` with a deterministic random source that fills every buffer
    /// with the byte `1`. The resulting shares are reproducible across runs.
    #[cfg(feature = "test-util")]
    pub fn test_vec_shard(
        &self,
        measurement: &T::Measurement,
    ) -> Result<Vec<Prio3InputShare<T::Field, L>>, VdafError> {
        self.shard_with_rand_source(measurement, |buf| {
            buf.fill(1);
            Ok(())
        })
    }

    /// Converts an aggregator index into its one-byte aggregator ID.
    ///
    /// # Errors
    /// Returns an error if `agg_id` is not below the configured number of aggregators.
    fn role_try_from(&self, agg_id: usize) -> Result<u8, VdafError> {
        if agg_id >= self.num_aggregators as usize {
            return Err(VdafError::Uncategorized("unexpected aggregator id".into()));
        }
        // Cannot fail: agg_id < num_aggregators, which is itself a u8.
        Ok(u8::try_from(agg_id).unwrap())
    }
}
/// Associated types of the Prio3 VDAF. Most are taken directly from the FLP type `T`.
/// Prio3 has neither an aggregation parameter nor a public share.
impl<T, P, const L: usize> Vdaf for Prio3<T, P, L>
where
    T: Type,
    P: Prg<L>,
{
    const ID: u32 = T::ID;

    // Measurement and result types come straight from the FLP type.
    type Measurement = T::Measurement;
    type AggregateResult = T::AggregateResult;

    // No aggregation parameter and no public share.
    type AggregationParam = ();
    type PublicShare = ();

    // Share types are all over the FLP type's field.
    type InputShare = Prio3InputShare<T::Field, L>;
    type OutputShare = OutputShare<T::Field>;
    type AggregateShare = AggregateShare<T::Field>;

    /// Returns the configured number of aggregators as a `usize`.
    fn num_aggregators(&self) -> usize {
        usize::from(self.num_aggregators)
    }
}
/// One aggregator's input share.
///
/// It holds a share of the encoded measurement and a share of the proof. When the FLP type
/// uses joint randomness, it also holds the parameter needed to recompute that randomness.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Prio3InputShare<F, const L: usize> {
    /// Share of the encoded measurement: an explicit vector for the leader, a seed for a
    /// helper.
    input_share: Share<F, L>,
    /// Share of the FLP proof, in the same leader/helper form as `input_share`.
    proof_share: Share<F, L>,
    /// Blind and seed hints for the joint randomness; `None` when the type uses none.
    joint_rand_param: Option<JointRandParam<L>>,
}
impl<F: FieldElement, const L: usize> Encode for Prio3InputShare<F, L> {
    /// Writes, in order:
    /// 1. the input share;
    /// 2. the proof share;
    /// 3. if present, the joint-randomness blind followed by each seed hint.
    ///
    /// # Panics
    /// Panics if exactly one of the input and proof shares is a leader share. Such a mixed
    /// share has no unambiguous encoding.
    fn encode(&self, bytes: &mut Vec<u8>) {
        match (&self.input_share, &self.proof_share) {
            (Share::Leader(_), Share::Helper(_)) | (Share::Helper(_), Share::Leader(_)) => {
                panic!("tried to encode input share with ambiguous encoding")
            }
            _ => (),
        }

        self.input_share.encode(bytes);
        self.proof_share.encode(bytes);

        if let Some(param) = &self.joint_rand_param {
            param.blind.encode(bytes);
            param.seed_hint.iter().for_each(|hint| hint.encode(bytes));
        }
    }
}
impl<'a, T, P, const L: usize> ParameterizedDecode<(&'a Prio3<T, P, L>, usize)>
    for Prio3InputShare<T::Field, L>
where
    T: Type,
    P: Prg<L>,
{
    /// Decodes the input share of aggregator `agg_id` for the given Prio3 instance.
    ///
    /// The leader's shares are explicit vectors whose lengths come from the FLP type;
    /// helper shares are seeds. When the type uses joint randomness, the decoder also reads
    /// a blind followed by `num_aggregators - 1` seed hints.
    ///
    /// # Errors
    /// Returns `CodecError::Other` wrapping the `VdafError` for an out-of-range `agg_id`,
    /// and propagates any decoding failure.
    fn decode_with_param(
        (prio3, agg_id): &(&'a Prio3<T, P, L>, usize),
        bytes: &mut Cursor<&[u8]>,
    ) -> Result<Self, CodecError> {
        let role = match prio3.role_try_from(*agg_id) {
            Ok(role) => role,
            Err(e) => return Err(CodecError::Other(Box::new(e))),
        };
        let is_leader = role == 0;

        let input_decoder = if is_leader {
            ShareDecodingParameter::Leader(prio3.typ.input_len())
        } else {
            ShareDecodingParameter::Helper
        };
        let proof_decoder = if is_leader {
            ShareDecodingParameter::Leader(prio3.typ.proof_len())
        } else {
            ShareDecodingParameter::Helper
        };

        let input_share = Share::decode_with_param(&input_decoder, bytes)?;
        let proof_share = Share::decode_with_param(&proof_decoder, bytes)?;

        let joint_rand_param = if prio3.typ.joint_rand_len() == 0 {
            None
        } else {
            let blind = Seed::decode(bytes)?;
            // One hint per other aggregator.
            let mut seed_hint = Vec::new();
            for _ in 1..prio3.num_aggregators() {
                seed_hint.push(Seed::decode(bytes)?);
            }
            Some(JointRandParam { blind, seed_hint })
        };

        Ok(Self {
            input_share,
            proof_share,
            joint_rand_param,
        })
    }
}
/// Output of `prepare_init` for one aggregator; `prepare_preprocess` consumes one of these
/// from each aggregator.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Prio3PrepareShare<F, const L: usize> {
    /// This aggregator's share of the FLP verifier.
    verifier: Vec<F>,
    /// This aggregator's joint-randomness part; `None` when the type uses no joint
    /// randomness.
    joint_rand_part: Option<Seed<L>>,
}
impl<F: FieldElement, const L: usize> Encode for Prio3PrepareShare<F, L> {
    /// Writes every verifier element, then the joint-randomness part if there is one.
    fn encode(&self, bytes: &mut Vec<u8>) {
        self.verifier
            .iter()
            .for_each(|element| element.encode(bytes));
        if let Some(part) = &self.joint_rand_part {
            part.encode(bytes);
        }
    }
}
impl<F: FieldElement, const L: usize> ParameterizedDecode<Prio3PrepareState<F, L>>
    for Prio3PrepareShare<F, L>
{
    /// Decodes a prepare share using the local prepare state as the decoding parameter.
    ///
    /// - The state's `verifier_len` fixes how many field elements to read.
    /// - A joint-randomness part is read only if the state holds a joint-randomness seed.
    fn decode_with_param(
        decoding_parameter: &Prio3PrepareState<F, L>,
        bytes: &mut Cursor<&[u8]>,
    ) -> Result<Self, CodecError> {
        let verifier = (0..decoding_parameter.verifier_len)
            .map(|_| F::decode(bytes))
            .collect::<Result<Vec<F>, _>>()?;
        let joint_rand_part = match decoding_parameter.joint_rand_seed {
            Some(_) => Some(Seed::decode(bytes)?),
            None => None,
        };
        Ok(Self {
            verifier,
            joint_rand_part,
        })
    }
}
/// Output of `prepare_preprocess`, handed to each aggregator's `prepare_step`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Prio3PrepareMessage<const L: usize> {
    /// Joint-randomness seed derived from all aggregators' parts; `None` when the type uses
    /// no joint randomness.
    joint_rand_seed: Option<Seed<L>>,
}
impl<const L: usize> Encode for Prio3PrepareMessage<L> {
    /// Writes the joint-randomness seed if present; otherwise writes nothing.
    fn encode(&self, bytes: &mut Vec<u8>) {
        self.joint_rand_seed
            .iter()
            .for_each(|seed| seed.encode(bytes));
    }
}
impl<F: FieldElement, const L: usize> ParameterizedDecode<Prio3PrepareState<F, L>>
    for Prio3PrepareMessage<L>
{
    /// Decodes a prepare message. A seed is read only if the local prepare state holds a
    /// joint-randomness seed.
    fn decode_with_param(
        decoding_parameter: &Prio3PrepareState<F, L>,
        bytes: &mut Cursor<&[u8]>,
    ) -> Result<Self, CodecError> {
        let joint_rand_seed = match decoding_parameter.joint_rand_seed {
            Some(_) => Some(Seed::decode(bytes)?),
            None => None,
        };
        Ok(Self { joint_rand_seed })
    }
}
impl<T, P, const L: usize> Client for Prio3<T, P, L>
where
    T: Type,
    P: Prg<L>,
{
    /// Shards `measurement` using OS randomness from `getrandom`. The public share is
    /// always `()`.
    #[allow(clippy::type_complexity)]
    fn shard(
        &self,
        measurement: &T::Measurement,
    ) -> Result<((), Vec<Prio3InputShare<T::Field, L>>), VdafError> {
        let input_shares = self.shard_with_rand_source(measurement, getrandom::getrandom)?;
        Ok(((), input_shares))
    }
}
/// State an aggregator keeps between `prepare_init` and `prepare_step`.
///
/// Only `input_share` and `joint_rand_seed` are encoded. When decoding, `agg_id` and
/// `verifier_len` are recovered from the `(Prio3, agg_id)` decoding parameter.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Prio3PrepareState<F, const L: usize> {
    /// The input share as received (a seed for helpers, an explicit vector for the leader).
    input_share: Share<F, L>,
    /// Joint-randomness seed this aggregator computed. `prepare_step` compares it with the
    /// prepare message.
    joint_rand_seed: Option<Seed<L>>,
    /// This aggregator's one-byte ID.
    agg_id: u8,
    /// Expected verifier share length, used when decoding peers' prepare shares.
    verifier_len: usize,
}
impl<F: FieldElement, const L: usize> Encode for Prio3PrepareState<F, L> {
    /// Writes the input share, then the joint-randomness seed if present. `agg_id` and
    /// `verifier_len` are not encoded.
    fn encode(&self, bytes: &mut Vec<u8>) {
        self.input_share.encode(bytes);
        self.joint_rand_seed
            .iter()
            .for_each(|seed| seed.encode(bytes));
    }
}
impl<'a, T, P, const L: usize> ParameterizedDecode<(&'a Prio3<T, P, L>, usize)>
    for Prio3PrepareState<T::Field, L>
where
    T: Type,
    P: Prg<L>,
{
    /// Decodes the prepare state of aggregator `agg_id`.
    ///
    /// `agg_id` and `verifier_len` are taken from the decoding parameter rather than the
    /// byte stream.
    ///
    /// # Errors
    /// Returns `CodecError::Other` wrapping the `VdafError` for an out-of-range `agg_id`,
    /// and propagates any decoding failure.
    fn decode_with_param(
        (prio3, agg_id): &(&'a Prio3<T, P, L>, usize),
        bytes: &mut Cursor<&[u8]>,
    ) -> Result<Self, CodecError> {
        let role = match prio3.role_try_from(*agg_id) {
            Ok(role) => role,
            Err(e) => return Err(CodecError::Other(Box::new(e))),
        };

        let share_decoder = match role {
            0 => ShareDecodingParameter::Leader(prio3.typ.input_len()),
            _ => ShareDecodingParameter::Helper,
        };
        let input_share = Share::decode_with_param(&share_decoder, bytes)?;

        let joint_rand_seed = match prio3.typ.joint_rand_len() {
            0 => None,
            _ => Some(Seed::decode(bytes)?),
        };

        Ok(Self {
            input_share,
            joint_rand_seed,
            agg_id: role,
            verifier_len: prio3.typ.verifier_len(),
        })
    }
}
impl<T, P, const L: usize> Aggregator<L> for Prio3<T, P, L>
where
    T: Type,
    P: Prg<L>,
{
    type PrepareState = Prio3PrepareState<T::Field, L>;
    type PrepareShare = Prio3PrepareShare<T::Field, L>;
    type PrepareMessage = Prio3PrepareMessage<L>;

    /// Begins preparation of aggregator `agg_id`'s input share.
    ///
    /// Steps:
    /// 1. Expand the input and proof shares (helpers hold only seeds).
    /// 2. When the type uses joint randomness, recompute this aggregator's part and derive
    ///    the joint-randomness seed from the seed hints.
    /// 3. Run the FLP query to produce this aggregator's verifier share.
    ///
    /// # Errors
    /// Returns an error if any of the following holds:
    /// - `agg_id` is out of range;
    /// - joint randomness is required but the input share lacks its parameter;
    /// - the input share carries the wrong number of seed hints;
    /// - the FLP query fails.
    #[allow(clippy::type_complexity)]
    fn prepare_init(
        &self,
        verify_key: &[u8; L],
        agg_id: usize,
        _agg_param: &(),
        nonce: &[u8],
        _public_share: &(),
        msg: &Prio3InputShare<T::Field, L>,
    ) -> Result<
        (
            Prio3PrepareState<T::Field, L>,
            Prio3PrepareShare<T::Field, L>,
        ),
        VdafError,
    > {
        let agg_id = self.role_try_from(agg_id)?;
        // info = VERSION || ID || agg_id; its first DST_LEN bytes are the domain separation tag.
        let mut info = [0; DST_LEN + 1];
        info[..VERSION.len()].copy_from_slice(VERSION);
        info[VERSION.len()..DST_LEN].copy_from_slice(&Self::ID.to_be_bytes());
        info[DST_LEN] = agg_id;
        let domain_separation_tag = &info[..DST_LEN];

        // Query randomness depends only on the verification key, the tag, and the nonce.
        // All aggregators sharing `verify_key` therefore derive the same values.
        let mut deriver = P::init(verify_key);
        deriver.update(domain_separation_tag);
        deriver.update(&[255]);
        deriver.update(nonce);
        let query_rand_prng = Prng::from_seed_stream(deriver.into_seed_stream());

        // The leader holds explicit vectors; a helper's seeds are expanded here.
        let expanded_input_share: Vec<T::Field>;
        let input_share = match msg.input_share {
            Share::Leader(ref data) => data,
            Share::Helper(ref seed) => {
                let prng = Prng::from_seed_stream(P::seed_stream(seed, &info));
                expanded_input_share = prng.take(self.typ.input_len()).collect();
                &expanded_input_share
            }
        };
        let expanded_proof_share: Vec<T::Field>;
        let proof_share = match msg.proof_share {
            Share::Leader(ref data) => data,
            Share::Helper(ref seed) => {
                let prng = Prng::from_seed_stream(P::seed_stream(seed, &info));
                expanded_proof_share = prng.take(self.typ.proof_len()).collect();
                &expanded_proof_share
            }
        };

        let (joint_rand_seed, joint_rand_seed_part, joint_rand) = if self.typ.joint_rand_len() > 0 {
            let joint_rand_param = msg.joint_rand_param.as_ref().ok_or_else(|| {
                VdafError::Uncategorized(
                    "input share is missing the joint randomness parameter".to_string(),
                )
            })?;
            let hints = &joint_rand_param.seed_hint;
            // One hint per other aggregator. Checking this up front also keeps the slicing
            // below in bounds, since agg_id < num_aggregators.
            let want_hints = self.num_aggregators as usize - 1;
            if hints.len() != want_hints {
                return Err(VdafError::Uncategorized(format!(
                    "unexpected number of joint randomness seed hints: got {}; want {}",
                    hints.len(),
                    want_hints,
                )));
            }

            // Recompute this aggregator's part from its blind and its expanded input share.
            let mut deriver = P::init(joint_rand_param.blind.as_ref());
            deriver.update(&info);
            for x in input_share {
                deriver.update(&(*x).into());
            }
            let joint_rand_seed_part = deriver.into_seed();

            // Insert our own part at position agg_id so all parts are in aggregator-ID order.
            let joint_rand_seed = Self::derive_joint_randomness(
                hints[..agg_id as usize]
                    .iter()
                    .chain(std::iter::once(&joint_rand_seed_part))
                    .chain(hints[agg_id as usize..].iter()),
            );
            let prng: Prng<T::Field, _> =
                Prng::from_seed_stream(P::seed_stream(&joint_rand_seed, domain_separation_tag));
            (
                Some(joint_rand_seed),
                Some(joint_rand_seed_part),
                prng.take(self.typ.joint_rand_len()).collect(),
            )
        } else {
            (None, None, Vec::new())
        };

        let query_rand: Vec<T::Field> = query_rand_prng.take(self.typ.query_rand_len()).collect();
        let verifier_share = self.typ.query(
            input_share,
            proof_share,
            &query_rand,
            &joint_rand,
            self.num_aggregators as usize,
        )?;

        Ok((
            Prio3PrepareState {
                input_share: msg.input_share.clone(),
                joint_rand_seed,
                agg_id,
                verifier_len: verifier_share.len(),
            },
            Prio3PrepareShare {
                verifier: verifier_share,
                joint_rand_part: joint_rand_seed_part,
            },
        ))
    }

    /// Combines all aggregators' prepare shares into the prepare message.
    ///
    /// Sums the verifier shares and runs the FLP decision on the sum. When joint randomness
    /// is used, it also derives the joint-randomness seed from the parts, in the order the
    /// shares are supplied. That order presumably must be aggregator-ID order; TODO confirm
    /// with callers.
    ///
    /// # Errors
    /// Returns an error if any of the following holds:
    /// - a verifier share has the wrong length;
    /// - a share lacks its joint-randomness part;
    /// - the number of shares differs from the number of aggregators;
    /// - the proof fails verification.
    fn prepare_preprocess<M: IntoIterator<Item = Prio3PrepareShare<T::Field, L>>>(
        &self,
        inputs: M,
    ) -> Result<Prio3PrepareMessage<L>, VdafError> {
        let uses_joint_rand = self.typ.joint_rand_len() > 0;
        let mut verifier = vec![T::Field::zero(); self.typ.verifier_len()];
        let mut joint_rand_parts = Vec::with_capacity(self.num_aggregators());
        // Count as usize. A u8 counter (inferred from comparing against the u8 field) would
        // overflow past 255 shares: a panic in debug builds, and in release builds a wrap
        // that lets 256 + n shares masquerade as n.
        let mut count: usize = 0;
        for share in inputs {
            count += 1;
            if share.verifier.len() != verifier.len() {
                return Err(VdafError::Uncategorized(format!(
                    "unexpected verifier share length: got {}; want {}",
                    share.verifier.len(),
                    verifier.len(),
                )));
            }
            if uses_joint_rand {
                let joint_rand_seed_part = share.joint_rand_part.ok_or_else(|| {
                    VdafError::Uncategorized(
                        "prepare share is missing the joint randomness part".to_string(),
                    )
                })?;
                joint_rand_parts.push(joint_rand_seed_part);
            }
            for (x, y) in verifier.iter_mut().zip(share.verifier) {
                *x += y;
            }
        }

        if count != self.num_aggregators() {
            return Err(VdafError::Uncategorized(format!(
                "unexpected message count: got {}; want {}",
                count, self.num_aggregators,
            )));
        }

        if !self.typ.decide(&verifier)? {
            return Err(VdafError::Uncategorized(
                "proof verifier check failed".into(),
            ));
        }

        let joint_rand_seed = if uses_joint_rand {
            Some(Self::derive_joint_randomness(joint_rand_parts.iter()))
        } else {
            None
        };
        Ok(Prio3PrepareMessage { joint_rand_seed })
    }

    /// Finishes preparation.
    ///
    /// Checks that the locally computed joint-randomness seed matches the one in the
    /// prepare message. Then expands the input share, if it is a helper seed, and truncates
    /// it into the output share.
    ///
    /// # Errors
    /// Returns an error if the seeds are missing or differ, or if truncation fails.
    fn prepare_step(
        &self,
        step: Prio3PrepareState<T::Field, L>,
        msg: Prio3PrepareMessage<L>,
    ) -> Result<PrepareTransition<Self, L>, VdafError> {
        if self.typ.joint_rand_len() > 0 {
            // Both seeds must be present and equal. A missing seed is reported as a
            // mismatch instead of panicking.
            let seeds_agree = matches!(
                (step.joint_rand_seed.as_ref(), msg.joint_rand_seed.as_ref()),
                (Some(ours), Some(theirs)) if ours == theirs
            );
            if !seeds_agree {
                return Err(VdafError::Uncategorized(
                    "joint randomness mismatch".to_string(),
                ));
            }
        }

        let input_share = match step.input_share {
            Share::Leader(data) => data,
            Share::Helper(seed) => {
                let mut info = [0; DST_LEN + 1];
                info[..VERSION.len()].copy_from_slice(VERSION);
                info[VERSION.len()..DST_LEN].copy_from_slice(&Self::ID.to_be_bytes());
                info[DST_LEN] = step.agg_id;
                let prng = Prng::from_seed_stream(P::seed_stream(&seed, &info));
                prng.take(self.typ.input_len()).collect()
            }
        };

        let output_share = OutputShare(self.typ.truncate(input_share)?);
        Ok(PrepareTransition::Finish(output_share))
    }

    /// Sums output shares into an aggregate share of length `output_len`.
    ///
    /// # Errors
    /// Propagates any error from accumulating an output share.
    fn aggregate<It: IntoIterator<Item = OutputShare<T::Field>>>(
        &self,
        _agg_param: &(),
        output_shares: It,
    ) -> Result<AggregateShare<T::Field>, VdafError> {
        let mut agg_share = AggregateShare(vec![T::Field::zero(); self.typ.output_len()]);
        for output_share in output_shares.into_iter() {
            agg_share.accumulate(&output_share)?;
        }
        Ok(agg_share)
    }
}
impl<T, P, const L: usize> Collector for Prio3<T, P, L>
where
    T: Type,
    P: Prg<L>,
{
    /// Merges all aggregate shares and decodes the sum into the final aggregate result.
    ///
    /// # Errors
    /// Propagates errors from merging shares or from the FLP type's result decoding.
    fn unshard<It: IntoIterator<Item = AggregateShare<T::Field>>>(
        &self,
        _agg_param: &(),
        agg_shares: It,
        num_measurements: usize,
    ) -> Result<T::AggregateResult, VdafError> {
        let zero = AggregateShare(vec![T::Field::zero(); self.typ.output_len()]);
        let total = agg_shares
            .into_iter()
            .try_fold(zero, |mut running, share| {
                running.merge(&share)?;
                Ok::<_, VdafError>(running)
            })?;
        Ok(self.typ.decode_result(&total.0, num_measurements)?)
    }
}
/// Data an aggregator needs to recompute the joint-randomness seed during preparation.
#[derive(Clone, Debug, Eq, PartialEq)]
struct JointRandParam<const L: usize> {
    /// Joint-randomness parts of every other aggregator, in aggregator-ID order with this
    /// aggregator's own part omitted.
    seed_hint: Vec<Seed<L>>,
    /// PRG key used to derive this aggregator's own joint-randomness part.
    blind: Seed<L>,
}
/// Seeds generated for one helper during sharding, before being packaged into a
/// `Prio3InputShare`.
#[derive(Clone)]
struct HelperShare<const L: usize> {
    /// Seed that expands to the helper's input share.
    input_share: Seed<L>,
    /// Seed that expands to the helper's proof share.
    proof_share: Seed<L>,
    /// The helper's blind. Its seed hints start empty and are filled in later in sharding.
    joint_rand_param: JointRandParam<L>,
}
impl<const L: usize> HelperShare<L> {
    /// Draws fresh seeds for a helper in a fixed order: input share, proof share, then
    /// joint-randomness blind. The seed hints start out empty.
    ///
    /// # Errors
    /// Propagates any failure of `rand_source`.
    fn from_rand_source(rand_source: RandSource) -> Result<Self, VdafError> {
        let input_share = Seed::from_rand_source(rand_source)?;
        let proof_share = Seed::from_rand_source(rand_source)?;
        let blind = Seed::from_rand_source(rand_source)?;
        Ok(Self {
            input_share,
            proof_share,
            joint_rand_param: JointRandParam {
                seed_hint: Vec::new(),
                blind,
            },
        })
    }
}
/// Validates that the number of aggregators lies in `1..=254`.
///
/// # Errors
/// Returns `VdafError::Uncategorized` naming the violated bound and the offending count.
fn check_num_aggregators(num_aggregators: u8) -> Result<(), VdafError> {
    match num_aggregators {
        0 => Err(VdafError::Uncategorized(format!(
            "at least one aggregator is required; got {}",
            num_aggregators
        ))),
        too_many if too_many > 254 => Err(VdafError::Uncategorized(format!(
            "number of aggregators must not exceed 254; got {}",
            too_many
        ))),
        _ => Ok(()),
    }
}
/// Unit tests exercising each Prio3 instantiation end to end, including tamper cases
/// that must be rejected during preparation.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::flp::gadgets::ParallelSumGadget;
    use crate::vdaf::{run_vdaf, run_vdaf_prepare};
    use assert_matches::assert_matches;
    use fixed::types::extra::{U15, U31, U63};
    use fixed::{FixedI16, FixedI32, FixedI64};
    use fixed_macro::fixed;
    use rand::prelude::*;

    /// Count covers:
    /// - aggregation with two aggregators;
    /// - preparation of the measurements 0 and 1;
    /// - prepare-state round-tripping;
    /// - aggregation with three aggregators.
    #[test]
    fn test_prio3_count() {
        let prio3 = Prio3::new_aes128_count(2).unwrap();
        assert_eq!(run_vdaf(&prio3, &(), [1, 0, 0, 1, 1]).unwrap(), 3);
        let mut verify_key = [0; 16];
        thread_rng().fill(&mut verify_key[..]);
        let nonce = b"This is a good nonce.";
        let (public_share, input_shares) = prio3.shard(&0).unwrap();
        run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares).unwrap();
        let (public_share, input_shares) = prio3.shard(&1).unwrap();
        run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares).unwrap();
        test_prepare_state_serialization(&prio3, &1).unwrap();
        let prio3_extra_helper = Prio3::new_aes128_count(3).unwrap();
        assert_eq!(
            run_vdaf(&prio3_extra_helper, &(), [1, 0, 0, 1, 1]).unwrap(),
            3,
        );
    }

    /// Sum with three aggregators: aggregation, plus four tamper cases on the leader's
    /// share. Each case must make preparation fail with `VdafError::Uncategorized`.
    #[test]
    fn test_prio3_sum() {
        let prio3 = Prio3::new_aes128_sum(3, 16).unwrap();
        assert_eq!(
            run_vdaf(&prio3, &(), [0, (1 << 16) - 1, 0, 1, 1]).unwrap(),
            (1 << 16) + 1
        );
        let mut verify_key = [0; 16];
        thread_rng().fill(&mut verify_key[..]);
        let nonce = b"This is a good nonce.";
        // Tamper with the leader's joint-randomness blind.
        let (public_share, mut input_shares) = prio3.shard(&1).unwrap();
        input_shares[0].joint_rand_param.as_mut().unwrap().blind.0[0] ^= 255;
        let result = run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares);
        assert_matches!(result, Err(VdafError::Uncategorized(_)));
        // Tamper with the leader's first seed hint.
        let (public_share, mut input_shares) = prio3.shard(&1).unwrap();
        input_shares[0].joint_rand_param.as_mut().unwrap().seed_hint[0].0[0] ^= 255;
        let result = run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares);
        assert_matches!(result, Err(VdafError::Uncategorized(_)));
        // Tamper with the leader's input share.
        let (public_share, mut input_shares) = prio3.shard(&1).unwrap();
        assert_matches!(input_shares[0].input_share, Share::Leader(ref mut data) => {
            data[0] += Field128::one();
        });
        let result = run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares);
        assert_matches!(result, Err(VdafError::Uncategorized(_)));
        // Tamper with the leader's proof share.
        let (public_share, mut input_shares) = prio3.shard(&1).unwrap();
        assert_matches!(input_shares[0].proof_share, Share::Leader(ref mut data) => {
            data[0] += Field128::one();
        });
        let result = run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares);
        assert_matches!(result, Err(VdafError::Uncategorized(_)));
        test_prepare_state_serialization(&prio3, &1).unwrap();
    }

    /// CountVec of length 20: a single measurement aggregates to itself.
    #[test]
    fn test_prio3_countvec() {
        let prio3 = Prio3::new_aes128_count_vec(2, 20).unwrap();
        assert_eq!(
            run_vdaf(
                &prio3,
                &(),
                [vec![
                    0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1,
                ]]
            )
            .unwrap(),
            vec![0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1,]
        );
    }

    /// Same as `test_prio3_countvec`, using the multithreaded gadget.
    #[test]
    #[cfg(feature = "multithreaded")]
    fn test_prio3_countvec_multithreaded() {
        let prio3 = Prio3::new_aes128_count_vec_multithreaded(2, 20).unwrap();
        assert_eq!(
            run_vdaf(
                &prio3,
                &(),
                [vec![
                    0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1,
                ]]
            )
            .unwrap(),
            vec![0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1,]
        );
    }

    /// Fixed-point bounded-L2 vector sum at 16-, 32-, and 64-bit precision. The
    /// multithreaded variants are also run when that feature is enabled.
    #[test]
    fn test_prio3_bounded_fpvec_sum() {
        type P<Fx> = Prio3Aes128FixedPointBoundedL2VecSum<Fx>;
        let ctor_16 = P::<FixedI16<U15>>::new_aes128_fixedpoint_boundedl2_vec_sum;
        let ctor_32 = P::<FixedI32<U31>>::new_aes128_fixedpoint_boundedl2_vec_sum;
        let ctor_64 = P::<FixedI64<U63>>::new_aes128_fixedpoint_boundedl2_vec_sum;
        #[cfg(feature = "multithreaded")]
        type PM<Fx> = Prio3Aes128FixedPointBoundedL2VecSumMultithreaded<Fx>;
        #[cfg(feature = "multithreaded")]
        let ctor_mt_16 = PM::<FixedI16<U15>>::new_aes128_fixedpoint_boundedl2_vec_sum_multithreaded;
        #[cfg(feature = "multithreaded")]
        let ctor_mt_32 = PM::<FixedI32<U31>>::new_aes128_fixedpoint_boundedl2_vec_sum_multithreaded;
        #[cfg(feature = "multithreaded")]
        let ctor_mt_64 = PM::<FixedI64<U63>>::new_aes128_fixedpoint_boundedl2_vec_sum_multithreaded;
        // 16-bit fixed point.
        {
            let fp16_4_inv = fixed!(0.25: I1F15);
            let fp16_8_inv = fixed!(0.125: I1F15);
            let fp16_16_inv = fixed!(0.0625: I1F15);
            {
                let prio3_16 = ctor_16(2, 3).unwrap();
                test_fixed(fp16_4_inv, fp16_8_inv, fp16_16_inv, prio3_16);
            }
            #[cfg(feature = "multithreaded")]
            {
                let prio3_16_mt = ctor_mt_16(2, 3).unwrap();
                test_fixed(fp16_4_inv, fp16_8_inv, fp16_16_inv, prio3_16_mt);
            }
        }
        // 32-bit fixed point.
        {
            let fp32_4_inv = fixed!(0.25: I1F31);
            let fp32_8_inv = fixed!(0.125: I1F31);
            let fp32_16_inv = fixed!(0.0625: I1F31);
            {
                let prio3_32 = ctor_32(2, 3).unwrap();
                test_fixed(fp32_4_inv, fp32_8_inv, fp32_16_inv, prio3_32);
            }
            #[cfg(feature = "multithreaded")]
            {
                let prio3_32_mt = ctor_mt_32(2, 3).unwrap();
                test_fixed(fp32_4_inv, fp32_8_inv, fp32_16_inv, prio3_32_mt);
            }
        }
        // 64-bit fixed point.
        {
            let fp64_4_inv = fixed!(0.25: I1F63);
            let fp64_8_inv = fixed!(0.125: I1F63);
            let fp64_16_inv = fixed!(0.0625: I1F63);
            {
                let prio3_64 = ctor_64(2, 3).unwrap();
                test_fixed(fp64_4_inv, fp64_8_inv, fp64_16_inv, prio3_64);
            }
            #[cfg(feature = "multithreaded")]
            {
                let prio3_64_mt = ctor_mt_64(2, 3).unwrap();
                test_fixed(fp64_4_inv, fp64_8_inv, fp64_16_inv, prio3_64_mt);
            }
        }

        /// Shared checks for one fixed-point instantiation:
        /// - three aggregation scenarios with known sums;
        /// - four tamper cases on the leader's share, each of which must fail preparation;
        /// - prepare-state round-tripping.
        fn test_fixed<Fx, PE, BPE>(
            fp_4_inv: Fx,
            fp_8_inv: Fx,
            fp_16_inv: Fx,
            prio3: Prio3<FixedPointBoundedL2VecSum<Fx, Field128, PE, BPE>, PrgAes128, 16>,
        ) where
            Fx: Fixed + CompatibleFloat<Field128> + std::ops::Neg<Output = Fx>,
            PE: Eq + ParallelSumGadget<Field128, PolyEval<Field128>> + Clone + 'static,
            BPE: Eq + ParallelSumGadget<Field128, BlindPolyEval<Field128>> + Clone + 'static,
        {
            let fp_vec1 = vec![fp_4_inv, fp_8_inv, fp_16_inv];
            let fp_vec2 = vec![fp_4_inv, fp_8_inv, fp_16_inv];
            let fp_vec3 = vec![-fp_4_inv, -fp_8_inv, -fp_16_inv];
            let fp_vec4 = vec![-fp_4_inv, -fp_8_inv, -fp_16_inv];
            let fp_vec5 = vec![fp_4_inv, -fp_8_inv, -fp_16_inv];
            let fp_vec6 = vec![fp_4_inv, fp_8_inv, fp_16_inv];
            // Two identical positive vectors.
            let fp_list = [fp_vec1, fp_vec2];
            assert_eq!(
                run_vdaf(&prio3, &(), fp_list).unwrap(),
                vec!(0.5, 0.25, 0.125),
            );
            // Two identical negative vectors.
            let fp_list2 = [fp_vec3, fp_vec4];
            assert_eq!(
                run_vdaf(&prio3, &(), fp_list2).unwrap(),
                vec!(-0.5, -0.25, -0.125),
            );
            // Mixed signs that cancel in the last two components.
            let fp_list3 = [fp_vec5, fp_vec6];
            assert_eq!(
                run_vdaf(&prio3, &(), fp_list3).unwrap(),
                vec!(0.5, 0.0, 0.0),
            );
            let mut verify_key = [0; 16];
            thread_rng().fill(&mut verify_key[..]);
            let nonce = b"This is a good nonce.";
            // Tamper with the leader's joint-randomness blind.
            let (public_share, mut input_shares) =
                prio3.shard(&vec![fp_4_inv, fp_8_inv, fp_16_inv]).unwrap();
            input_shares[0].joint_rand_param.as_mut().unwrap().blind.0[0] ^= 255;
            let result =
                run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares);
            assert_matches!(result, Err(VdafError::Uncategorized(_)));
            // Tamper with the leader's first seed hint.
            let (public_share, mut input_shares) =
                prio3.shard(&vec![fp_4_inv, fp_8_inv, fp_16_inv]).unwrap();
            input_shares[0].joint_rand_param.as_mut().unwrap().seed_hint[0].0[0] ^= 255;
            let result =
                run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares);
            assert_matches!(result, Err(VdafError::Uncategorized(_)));
            // Tamper with the leader's input share.
            let (public_share, mut input_shares) =
                prio3.shard(&vec![fp_4_inv, fp_8_inv, fp_16_inv]).unwrap();
            assert_matches!(input_shares[0].input_share, Share::Leader(ref mut data) => {
                data[0] += Field128::one();
            });
            let result =
                run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares);
            assert_matches!(result, Err(VdafError::Uncategorized(_)));
            // Tamper with the leader's proof share.
            let (public_share, mut input_shares) =
                prio3.shard(&vec![fp_4_inv, fp_8_inv, fp_16_inv]).unwrap();
            assert_matches!(input_shares[0].proof_share, Share::Leader(ref mut data) => {
                data[0] += Field128::one();
            });
            let result =
                run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares);
            assert_matches!(result, Err(VdafError::Uncategorized(_)));
            test_prepare_state_serialization(&prio3, &vec![fp_4_inv, fp_8_inv, fp_16_inv]).unwrap();
        }
    }

    /// Histogram with boundaries [0, 10, 20]: each measurement lands in the expected
    /// bucket (out of four).
    #[test]
    fn test_prio3_histogram() {
        let prio3 = Prio3::new_aes128_histogram(2, &[0, 10, 20]).unwrap();
        assert_eq!(
            run_vdaf(&prio3, &(), [0, 10, 20, 9999]).unwrap(),
            vec![1, 1, 1, 1]
        );
        assert_eq!(run_vdaf(&prio3, &(), [0]).unwrap(), vec![1, 0, 0, 0]);
        assert_eq!(run_vdaf(&prio3, &(), [5]).unwrap(), vec![0, 1, 0, 0]);
        assert_eq!(run_vdaf(&prio3, &(), [10]).unwrap(), vec![0, 1, 0, 0]);
        assert_eq!(run_vdaf(&prio3, &(), [15]).unwrap(), vec![0, 0, 1, 0]);
        assert_eq!(run_vdaf(&prio3, &(), [20]).unwrap(), vec![0, 0, 1, 0]);
        assert_eq!(run_vdaf(&prio3, &(), [25]).unwrap(), vec![0, 0, 0, 1]);
        test_prepare_state_serialization(&prio3, &23).unwrap();
    }

    /// Average with 64-bit measurements over several small inputs.
    #[test]
    fn test_prio3_average() {
        let prio3 = Prio3::new_aes128_average(2, 64).unwrap();
        assert_eq!(run_vdaf(&prio3, &(), [17, 8]).unwrap(), 12.5f64);
        assert_eq!(run_vdaf(&prio3, &(), [1, 1, 1, 1]).unwrap(), 1f64);
        assert_eq!(run_vdaf(&prio3, &(), [0, 0, 0, 1]).unwrap(), 0.25f64);
        assert_eq!(
            run_vdaf(&prio3, &(), [1, 11, 111, 1111, 3, 8]).unwrap(),
            207.5f64
        );
    }

    /// With five aggregators, no two input shares may share a helper seed or a
    /// joint-randomness parameter.
    #[test]
    fn test_prio3_input_share() {
        let prio3 = Prio3::new_aes128_sum(5, 16).unwrap();
        let (_public_share, input_shares) = prio3.shard(&1).unwrap();
        for (i, x) in input_shares.iter().enumerate() {
            for (j, y) in input_shares.iter().enumerate() {
                if i != j {
                    if let (Share::Helper(left), Share::Helper(right)) =
                        (&x.input_share, &y.input_share)
                    {
                        assert_ne!(left, right);
                    }
                    if let (Share::Helper(left), Share::Helper(right)) =
                        (&x.proof_share, &y.proof_share)
                    {
                        assert_ne!(left, right);
                    }
                    assert_ne!(x.joint_rand_param, y.joint_rand_param);
                }
            }
        }
    }

    /// Shards `measurement`, runs `prepare_init` for every aggregator with an empty nonce,
    /// and checks that each prepare state survives an encode/decode round trip unchanged.
    fn test_prepare_state_serialization<T, P, const L: usize>(
        prio3: &Prio3<T, P, L>,
        measurement: &T::Measurement,
    ) -> Result<(), VdafError>
    where
        T: Type,
        P: Prg<L>,
    {
        let mut verify_key = [0; L];
        thread_rng().fill(&mut verify_key[..]);
        let (public_share, input_shares) = prio3.shard(measurement)?;
        for (agg_id, input_share) in input_shares.iter().enumerate() {
            let (want, _msg) =
                prio3.prepare_init(&verify_key, agg_id, &(), &[], &public_share, input_share)?;
            let got =
                Prio3PrepareState::get_decoded_with_param(&(prio3, agg_id), &want.get_encoded())
                    .expect("failed to decode prepare step");
            assert_eq!(got, want);
        }
        Ok(())
    }
}