prio 0.10.0

Implementation of the Prio aggregation system core: https://crypto.stanford.edu/prio/
Documentation
// SPDX-License-Identifier: MPL-2.0

//! Port of the ENPA Prio system to a VDAF. It is backwards compatible with
//! [`Client`](crate::client::Client) and [`Server`](crate::server::Server).

use crate::{
    client as v2_client,
    codec::{CodecError, Decode, Encode, ParameterizedDecode},
    field::{FieldElement, FieldPrio2},
    prng::Prng,
    server as v2_server,
    util::proof_length,
    vdaf::{
        prg::Seed, Aggregatable, AggregateShare, Aggregator, Client, Collector, OutputShare,
        PrepareTransition, Share, ShareDecodingParameter, Vdaf, VdafError,
    },
};
use ring::hmac;
use std::{
    convert::{TryFrom, TryInto},
    io::Cursor,
};

/// The Prio2 VDAF. It supports the same measurement type as
/// [`Prio3Aes128CountVec`](crate::vdaf::prio3::Prio3Aes128CountVec) but uses the proof system
/// and finite field deployed in ENPA.
///
/// Construct instances with [`Prio2::new`], which rejects input lengths that exceed the field's
/// capacity.
#[derive(Clone, Debug)]
pub struct Prio2 {
    /// Number of field elements in each measurement; `shard` rejects measurements of any other
    /// length.
    input_len: usize,
}

impl Prio2 {
    /// Returns an instance of the VDAF for the given input length.
    ///
    /// # Errors
    ///
    /// Returns [`VdafError::Uncategorized`] if the evaluation domain size
    /// `2 * (input_len + 1).next_power_of_two()` cannot be represented (overflows `usize` or does
    /// not fit in a `u32`), or if it exceeds `FieldPrio2::generator_order()`.
    pub fn new(input_len: usize) -> Result<Self, VdafError> {
        // Every step is checked: `next_power_of_two` wraps to 0 on overflow in release builds,
        // which would otherwise let an absurdly large `input_len` slip past the field-capacity
        // comparison below.
        let size = input_len
            .checked_add(1)
            .and_then(usize::checked_next_power_of_two)
            .and_then(|n| n.checked_mul(2))
            .and_then(|size| u32::try_from(size).ok())
            .ok_or_else(|| {
                VdafError::Uncategorized("input size exceeds memory capacity".into())
            })?;

        if size > FieldPrio2::generator_order() {
            return Err(VdafError::Uncategorized(
                "input size exceeds field capacity".into(),
            ));
        }

        Ok(Prio2 { input_len })
    }

    /// Prepare an input share for aggregation using the given field element `query_rand` to
    /// compute the verifier share.
    ///
    /// In the [`Aggregator`](crate::vdaf::Aggregator) trait implementation for [`Prio2`], the
    /// query randomness is computed jointly by the Aggregators. This method is designed to be used
    /// in applications, like ENPA, in which the query randomness is instead chosen by a
    /// third-party.
    ///
    /// A Helper share (a seed) is first expanded to `proof_length(input_len)` field elements so
    /// that both roles pass a full combined input-and-proof share to the verifier. The returned
    /// prepare state holds the input share truncated to `input_len`.
    ///
    /// # Errors
    ///
    /// Returns [`VdafError::Uncategorized`] if the verification message cannot be generated.
    pub fn prepare_init_with_query_rand(
        &self,
        query_rand: FieldPrio2,
        input_share: &Share<FieldPrio2, 32>,
        is_leader: bool,
    ) -> Result<(Prio2PrepareState, Prio2PrepareShare), VdafError> {
        // Deferred initialization: only the Helper path needs to own an expanded copy, while the
        // Leader path borrows its share directly.
        let expanded_data: Vec<FieldPrio2>;
        let data = match input_share {
            Share::Leader(data) => data,
            Share::Helper(seed) => {
                expanded_data = Prng::from_prio2_seed(seed.as_ref())
                    .take(proof_length(self.input_len))
                    .collect();
                &expanded_data
            }
        };

        let mut mem = v2_server::ValidationMemory::new(self.input_len);
        let verifier_share = v2_server::generate_verification_message(
            self.input_len,
            query_rand,
            data, // Combined input and proof shares
            is_leader,
            &mut mem,
        )
        .map_err(|e| VdafError::Uncategorized(e.to_string()))?;

        Ok((
            Prio2PrepareState(input_share.truncated(self.input_len)),
            Prio2PrepareShare(verifier_share),
        ))
    }
}

impl Vdaf for Prio2 {
    /// Numeric identifier of the Prio2 VDAF.
    const ID: u32 = 0xFFFF_0000;

    // Measurements and aggregate results are both vectors of `input_len` unsigned integers.
    type Measurement = Vec<u32>;
    type AggregateResult = Vec<u32>;

    // Prio2 takes no aggregation parameter and produces no public share.
    type AggregationParam = ();
    type PublicShare = ();

    // Shares at each stage are over the ENPA field `FieldPrio2`.
    type InputShare = Share<FieldPrio2, 32>;
    type OutputShare = OutputShare<FieldPrio2>;
    type AggregateShare = AggregateShare<FieldPrio2>;

    /// Always two: a Leader (ID 0) and a Helper (ID 1).
    fn num_aggregators(&self) -> usize {
        // The scheme itself could be generalized to more Aggregators; this port fixes it at two.
        2
    }
}

impl Client for Prio2 {
    /// Splits `measurement` into a Leader share and a Helper share.
    ///
    /// The Leader share is the full proof vector produced by the ENPA client with a PRNG stream
    /// subtracted from it; the Helper share is just the seed of that stream, so the two shares
    /// add up to the original proof vector.
    ///
    /// # Errors
    ///
    /// Fails if `measurement` does not have exactly `input_len` elements, or if client memory or
    /// the Helper seed cannot be created.
    fn shard(&self, measurement: &Vec<u32>) -> Result<((), Vec<Share<FieldPrio2, 32>>), VdafError> {
        if measurement.len() != self.input_len {
            return Err(VdafError::Uncategorized("incorrect input length".into()));
        }
        let encoded: Vec<FieldPrio2> = measurement.iter().map(|&x| x.into()).collect();

        let mut client_mem = v2_client::ClientMemory::new(self.input_len)?;
        let mut leader_data = client_mem.prove_with(self.input_len, |slots: &mut [FieldPrio2]| {
            slots.copy_from_slice(&encoded);
        });

        // Mask the Leader's share with the stream derived from the Helper's seed.
        let helper_seed = Seed::generate()?;
        let mask = Prng::from_prio2_seed(helper_seed.as_ref());
        for (leader_elem, mask_elem) in leader_data.iter_mut().zip(mask) {
            *leader_elem -= mask_elem;
        }

        let shares = vec![Share::Leader(leader_data), Share::Helper(helper_seed)];
        Ok(((), shares))
    }
}

/// State of each [`Aggregator`](crate::vdaf::Aggregator) during the Preparation phase.
///
/// Wraps the Aggregator's input share after `Share::truncated(input_len)`. For the Leader this
/// is `input_len` field elements; for the Helper it is still a seed, which `prepare_step` expands
/// to `input_len` elements.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Prio2PrepareState(Share<FieldPrio2, 32>);

impl Encode for Prio2PrepareState {
    fn encode(&self, bytes: &mut Vec<u8>) {
        self.0.encode(bytes);
    }
}

impl<'a> ParameterizedDecode<(&'a Prio2, usize)> for Prio2PrepareState {
    fn decode_with_param(
        (prio2, agg_id): &(&'a Prio2, usize),
        bytes: &mut Cursor<&[u8]>,
    ) -> Result<Self, CodecError> {
        let share_decoder = if *agg_id == 0 {
            ShareDecodingParameter::Leader(prio2.input_len)
        } else {
            ShareDecodingParameter::Helper
        };
        let out_share = Share::decode_with_param(&share_decoder, bytes)?;
        Ok(Self(out_share))
    }
}

/// Message emitted by each [`Aggregator`](crate::vdaf::Aggregator) during the Preparation phase.
///
/// Wraps this Aggregator's share of the ENPA verification message. `prepare_preprocess` checks
/// the two Aggregators' shares together with `v2_server::is_valid_share`.
#[derive(Clone, Debug)]
pub struct Prio2PrepareShare(v2_server::VerificationMessage<FieldPrio2>);

impl Encode for Prio2PrepareShare {
    /// Writes the verification message as three consecutive field elements: `f_r`, `g_r`, then
    /// `h_r`.
    fn encode(&self, bytes: &mut Vec<u8>) {
        let msg = &self.0;
        msg.f_r.encode(bytes);
        msg.g_r.encode(bytes);
        msg.h_r.encode(bytes);
    }
}

impl ParameterizedDecode<Prio2PrepareState> for Prio2PrepareShare {
    /// Reads a verification message written by `encode`: `f_r`, `g_r`, then `h_r`.
    ///
    /// The prepare state is not needed to parse the message and is ignored.
    fn decode_with_param(
        _state: &Prio2PrepareState,
        bytes: &mut Cursor<&[u8]>,
    ) -> Result<Self, CodecError> {
        let f_r = FieldPrio2::decode(bytes)?;
        let g_r = FieldPrio2::decode(bytes)?;
        let h_r = FieldPrio2::decode(bytes)?;
        Ok(Self(v2_server::VerificationMessage { f_r, g_r, h_r }))
    }
}

impl Aggregator<32> for Prio2 {
    type PrepareState = Prio2PrepareState;
    type PrepareShare = Prio2PrepareShare;
    type PrepareMessage = ();

    /// Begins preparation of `input_share` for Aggregator `agg_id` (0 = Leader, 1 = Helper).
    ///
    /// The query randomness is derived from `agg_key` and `nonce` and then passed to
    /// [`Prio2::prepare_init_with_query_rand`].
    ///
    /// # Errors
    ///
    /// Fails for an aggregator ID other than 0 or 1, or if the verification message cannot be
    /// generated.
    fn prepare_init(
        &self,
        agg_key: &[u8; 32],
        agg_id: usize,
        _agg_param: &(),
        nonce: &[u8],
        _public_share: &Self::PublicShare,
        input_share: &Share<FieldPrio2, 32>,
    ) -> Result<(Prio2PrepareState, Prio2PrepareShare), VdafError> {
        let is_leader = role_try_from(agg_id)?;

        // In the ENPA Prio system, the query randomness is generated by a third party and
        // distributed to the Aggregators after they receive their input shares. In a VDAF, shared
        // randomness is derived from a nonce selected by the client. For Prio2 we compute the
        // query using HMAC-SHA256 evaluated over the nonce.
        let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, agg_key);
        let hmac_tag = hmac::sign(&hmac_key, nonce);
        let query_seed: &[u8; 32] = hmac_tag
            .as_ref()
            .try_into()
            .expect("HMAC-SHA256 tag must be exactly 32 bytes");
        let query_rand = Prng::from_prio2_seed(query_seed)
            .next()
            .expect("PRNG must yield at least one field element");

        self.prepare_init_with_query_rand(query_rand, input_share, is_leader)
    }

    /// Combines the two Aggregators' verifier shares and checks the proof.
    ///
    /// # Errors
    ///
    /// Fails unless exactly two shares are supplied and they pass
    /// `v2_server::is_valid_share`.
    fn prepare_preprocess<M: IntoIterator<Item = Prio2PrepareShare>>(
        &self,
        inputs: M,
    ) -> Result<(), VdafError> {
        let verifier_shares: Vec<v2_server::VerificationMessage<FieldPrio2>> =
            inputs.into_iter().map(|msg| msg.0).collect();
        let (first, second) = match verifier_shares.as_slice() {
            [first, second] => (first, second),
            _ => {
                return Err(VdafError::Uncategorized(
                    "wrong number of verifier shares".into(),
                ))
            }
        };

        if !v2_server::is_valid_share(first, second) {
            return Err(VdafError::Uncategorized(
                "proof verifier check failed".into(),
            ));
        }

        Ok(())
    }

    /// Finishes preparation by turning the stored share into an output share of `input_len`
    /// field elements. The Helper regenerates its share from its seed.
    fn prepare_step(
        &self,
        state: Prio2PrepareState,
        _input: (),
    ) -> Result<PrepareTransition<Self, 32>, VdafError> {
        let data = match state.0 {
            Share::Leader(data) => data,
            Share::Helper(seed) => {
                let prng = Prng::from_prio2_seed(seed.as_ref());
                prng.take(self.input_len).collect()
            }
        };
        Ok(PrepareTransition::Finish(OutputShare::from(data)))
    }

    /// Sums `out_shares` into an aggregate share of `input_len` elements. An empty iterator
    /// yields the all-zero share.
    ///
    /// # Errors
    ///
    /// Propagates any error from `accumulate`.
    fn aggregate<M: IntoIterator<Item = OutputShare<FieldPrio2>>>(
        &self,
        _agg_param: &(),
        out_shares: M,
    ) -> Result<AggregateShare<FieldPrio2>, VdafError> {
        let mut agg_share = AggregateShare(vec![FieldPrio2::zero(); self.input_len]);
        for out_share in out_shares.into_iter() {
            agg_share.accumulate(&out_share)?;
        }

        Ok(agg_share)
    }
}

impl Collector for Prio2 {
    /// Merges the Aggregators' aggregate shares and converts the sum back to integers.
    ///
    /// Starts from an all-zero vector of `input_len` elements, so the result has that length.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by `merge`.
    fn unshard<M: IntoIterator<Item = AggregateShare<FieldPrio2>>>(
        &self,
        _agg_param: &(),
        agg_shares: M,
        _num_measurements: usize,
    ) -> Result<Vec<u32>, VdafError> {
        let zero = AggregateShare(vec![FieldPrio2::zero(); self.input_len]);
        let total = agg_shares
            .into_iter()
            .try_fold(zero, |mut acc, share| -> Result<_, VdafError> {
                acc.merge(&share)?;
                Ok(acc)
            })?;

        Ok(total.0.into_iter().map(u32::from).collect())
    }
}

impl<'a> ParameterizedDecode<(&'a Prio2, usize)> for Share<FieldPrio2, 32> {
    /// Decodes the input share destined for Aggregator `agg_id`.
    ///
    /// The Leader's share is `proof_length(input_len)` field elements; the Helper's is a seed.
    /// Any other aggregator ID is reported as [`CodecError::Other`].
    fn decode_with_param(
        (prio2, agg_id): &(&'a Prio2, usize),
        bytes: &mut Cursor<&[u8]>,
    ) -> Result<Self, CodecError> {
        let param = match role_try_from(*agg_id) {
            Ok(true) => ShareDecodingParameter::Leader(proof_length(prio2.input_len)),
            Ok(false) => ShareDecodingParameter::Helper,
            Err(err) => return Err(CodecError::Other(Box::new(err))),
        };
        Share::decode_with_param(&param, bytes)
    }
}

/// Maps an aggregator ID to a role flag: `Ok(true)` for the Leader (ID 0), `Ok(false)` for the
/// Helper (ID 1), and an error for any other ID.
fn role_try_from(agg_id: usize) -> Result<bool, VdafError> {
    if agg_id <= 1 {
        Ok(agg_id == 0)
    } else {
        Err(VdafError::Uncategorized("unexpected aggregator id".into()))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        client::encode_simple,
        encrypt::{decrypt_share, encrypt_share, PrivateKey, PublicKey},
        field::random_vector,
        server::Server,
        vdaf::{run_vdaf, run_vdaf_prepare},
    };
    use rand::prelude::*;

    // Base64-encoded private keys for the two ENPA servers. The interop tests derive the matching
    // public keys from them to encrypt and decrypt input shares.
    const PRIV_KEY1: &str = "BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOgNt9HUC2E0w+9RqZx3XMkdEHBHfNuCSMpOwofVSq3TfyKwn0NrftKisKKVSaTOt5seJ67P5QL4hxgPWvxw==";
    const PRIV_KEY2: &str = "BNNOqoU54GPo+1gTPv+hCgA9U2ZCKd76yOMrWa1xTWgeb4LhFLMQIQoRwDVaW64g/WTdcxT4rDULoycUNFB60LER6hPEHg/ObBnRPV1rwS3nj9Bj0tbjVPPyL9p8QW8B+w==";

    /// Runs the VDAF end to end over five 6-element measurements; the expected aggregate is
    /// their element-wise sum.
    #[test]
    fn run_prio2() {
        let prio2 = Prio2::new(6).unwrap();

        assert_eq!(
            run_vdaf(
                &prio2,
                &(),
                [
                    vec![0, 0, 0, 0, 1, 0],
                    vec![0, 1, 0, 0, 0, 0],
                    vec![0, 1, 1, 0, 0, 0],
                    vec![1, 1, 1, 0, 0, 0],
                    vec![0, 0, 0, 0, 1, 1],
                ]
            )
            .unwrap(),
            vec![1, 3, 2, 0, 2, 1],
        );
    }

    /// Input shares produced by the ENPA client (`encode_simple`) and decrypted with the server
    /// keys must decode as Prio2 shares and pass VDAF preparation.
    #[test]
    fn enpa_client_interop() {
        let mut rng = thread_rng();
        let priv_key1 = PrivateKey::from_base64(PRIV_KEY1).unwrap();
        let priv_key2 = PrivateKey::from_base64(PRIV_KEY2).unwrap();
        let pub_key1 = PublicKey::from(&priv_key1);
        let pub_key2 = PublicKey::from(&priv_key2);

        let data: Vec<FieldPrio2> = [0, 0, 1, 1, 0]
            .iter()
            .map(|x| FieldPrio2::from(*x))
            .collect();
        let (encrypted_input_share1, encrypted_input_share2) =
            encode_simple(&data, pub_key1, pub_key2).unwrap();

        let input_share1 = decrypt_share(&encrypted_input_share1, &priv_key1).unwrap();
        let input_share2 = decrypt_share(&encrypted_input_share2, &priv_key2).unwrap();

        let prio2 = Prio2::new(data.len()).unwrap();
        let input_shares = vec![
            Share::get_decoded_with_param(&(&prio2, 0), &input_share1).unwrap(),
            Share::get_decoded_with_param(&(&prio2, 1), &input_share2).unwrap(),
        ];

        let verify_key = rng.gen();
        let mut nonce = [0; 16];
        rng.fill(&mut nonce);
        run_vdaf_prepare(&prio2, &verify_key, &(), &nonce, (), input_shares).unwrap();
    }

    /// Input shares produced by `Prio2::shard` and encrypted to the server public keys must be
    /// accepted by the ENPA `Server` verification and aggregation path.
    #[test]
    fn enpa_server_interop() {
        let priv_key1 = PrivateKey::from_base64(PRIV_KEY1).unwrap();
        let priv_key2 = PrivateKey::from_base64(PRIV_KEY2).unwrap();
        let pub_key1 = PublicKey::from(&priv_key1);
        let pub_key2 = PublicKey::from(&priv_key2);

        let data = vec![0, 0, 1, 1, 0];
        let prio2 = Prio2::new(data.len()).unwrap();
        let (_public_share, input_shares) = prio2.shard(&data).unwrap();

        let encrypted_input_share1 =
            encrypt_share(&input_shares[0].get_encoded(), &pub_key1).unwrap();
        let encrypted_input_share2 =
            encrypt_share(&input_shares[1].get_encoded(), &pub_key2).unwrap();

        let mut server1 = Server::new(data.len(), true, priv_key1).unwrap();
        let mut server2 = Server::new(data.len(), false, priv_key2).unwrap();

        // Both servers must evaluate at the same point, as a shared third party would supply.
        let eval_at: FieldPrio2 = random_vector(1).unwrap()[0];
        let verifier1 = server1
            .generate_verification_message(eval_at, &encrypted_input_share1)
            .unwrap();
        let verifier2 = server2
            .generate_verification_message(eval_at, &encrypted_input_share2)
            .unwrap();

        server1
            .aggregate(&encrypted_input_share1, &verifier1, &verifier2)
            .unwrap();
        server2
            .aggregate(&encrypted_input_share2, &verifier1, &verifier2)
            .unwrap();
    }

    /// Each Aggregator's prepare state must survive an encode/decode round trip unchanged.
    #[test]
    fn prepare_state_serialization() {
        let mut verify_key = [0; 32];
        thread_rng().fill(&mut verify_key[..]);
        let data = vec![0, 0, 1, 1, 0];
        let prio2 = Prio2::new(data.len()).unwrap();
        let (public_share, input_shares) = prio2.shard(&data).unwrap();
        for (agg_id, input_share) in input_shares.iter().enumerate() {
            let (want, _msg) = prio2
                .prepare_init(&verify_key, agg_id, &(), &[], &public_share, input_share)
                .unwrap();
            let got =
                Prio2PrepareState::get_decoded_with_param(&(&prio2, agg_id), &want.get_encoded())
                    .expect("failed to decode prepare step");
            assert_eq!(got, want);
        }
    }
}