use crate::{
codec::CodecError,
field::NttFriendlyFieldElement,
polynomial::{ntt_get_roots, poly_ntt, PolyNttTempMemory},
prng::Prng,
vdaf::{xof::SeedStreamAes128, VdafError},
};
use rand::{rng, Rng};
use std::convert::TryFrom;
/// Errors returned by the Prio2 client's proof (de)serialization helpers.
#[derive(Debug, thiserror::Error)]
pub enum SerializeError {
    /// The input slice's length differs from `proof_length(dimension)`; returned by
    /// `unpack_proof` / `unpack_proof_mut`.
    #[error("serialized input has wrong length")]
    UnpackInputSizeMismatch,
    /// A codec error, convertible via `From<CodecError>`.
    #[error(transparent)]
    Codec(#[from] CodecError),
}
/// Reusable PRNG and scratch buffers for constructing Prio2 proofs.
///
/// Buffers are allocated in `ClientMemory::new` from `n = (dimension + 1).next_power_of_two()`.
#[derive(Debug)]
pub(crate) struct ClientMemory<F> {
    /// Source of the random constant points `f0` and `g0` drawn for each proof.
    prng: Prng<F, SeedStreamAes128>,
    /// `n` point values of `f`: index 0 is `f0`, indices `1..=dimension` are the input data,
    /// and the rest are never written after allocation (zero-initialized).
    points_f: Vec<F>,
    /// `n` point values of `g`: index 0 is `g0`, indices `1..=dimension` are `data[i] - 1`,
    /// and the rest are never written after allocation (zero-initialized).
    points_g: Vec<F>,
    /// `2n` outputs of `interpolate_and_evaluate_at_2n` for `points_f`.
    evals_f: Vec<F>,
    /// `2n` outputs of `interpolate_and_evaluate_at_2n` for `points_g`.
    evals_g: Vec<F>,
    /// Roots from `ntt_get_roots(2 * n, false)`.
    roots_2n: Vec<F>,
    /// Roots from `ntt_get_roots(n, true)`.
    roots_n_inverted: Vec<F>,
    /// Temporary memory for `poly_ntt`, created with size `2 * n`.
    ntt_memory: PolyNttTempMemory<F>,
    /// `2n`-element intermediate buffer written by the first `poly_ntt` call and read by the
    /// second in `interpolate_and_evaluate_at_2n`.
    coeffs: Vec<F>,
}
impl<F: NttFriendlyFieldElement> ClientMemory<F> {
    /// Allocates proving state for inputs of length `dimension`.
    ///
    /// All buffers are sized from `n = (dimension + 1).next_power_of_two()`:
    /// - The point buffers hold `n` elements.
    /// - The evaluation, coefficient and NTT buffers hold `2n` elements.
    ///
    /// The PRNG is seeded with a value drawn from `rand::rng()`.
    ///
    /// # Errors
    ///
    /// Returns [`VdafError::Uncategorized`] if `2n` cannot be represented, either as a `usize`
    /// or as the field's integer type, or if it exceeds `F::generator_order()`.
    pub(crate) fn new(dimension: usize) -> Result<Self, VdafError> {
        let too_large = || VdafError::Uncategorized("input size exceeds field capacity".into());

        // Checked arithmetic. For an absurdly large `dimension`, the unchecked form panics in
        // debug builds. In release builds it wraps: `usize::MAX + 1 == 0`, so `n == 1`, and
        // construction would "succeed" with undersized buffers.
        let n = dimension
            .checked_add(1)
            .and_then(usize::checked_next_power_of_two)
            .ok_or_else(too_large)?;
        let two_n = n.checked_mul(2).ok_or_else(too_large)?;

        // The 2n evaluation points must not exceed the order of the field's generator.
        let size = F::Integer::try_from(two_n).map_err(|_| too_large())?;
        if size > F::generator_order() {
            return Err(too_large());
        }

        let mut rng = rng();
        Ok(Self {
            prng: Prng::from_prio2_seed(&rng.random()),
            points_f: vec![F::zero(); n],
            points_g: vec![F::zero(); n],
            evals_f: vec![F::zero(); two_n],
            evals_g: vec![F::zero(); two_n],
            roots_2n: ntt_get_roots(two_n, false),
            roots_n_inverted: ntt_get_roots(n, true),
            ntt_memory: PolyNttTempMemory::new(two_n),
            coeffs: vec![F::zero(); two_n],
        })
    }
}
impl<F: NttFriendlyFieldElement> ClientMemory<F> {
    /// Builds a Prio2 proof for an input of length `dimension`.
    ///
    /// Steps:
    /// 1. Allocates a zeroed proof of `proof_length(dimension)` elements.
    /// 2. Passes its data section (the first `dimension` elements) to `init_function` to fill in.
    /// 3. Calls `construct_proof`, which writes `f0`, `g0`, `h0` and the packed `h` points
    ///    using this memory's PRNG and scratch buffers.
    ///
    /// # Panics
    ///
    /// `dimension` should match the value passed to `ClientMemory::new`. If it needs more points
    /// than the scratch buffers were sized for, indexing inside `construct_proof` panics.
    pub(crate) fn prove_with<G>(&mut self, dimension: usize, init_function: G) -> Vec<F>
    where
        G: FnOnce(&mut [F]),
    {
        let mut proof = vec![F::zero(); proof_length(dimension)];
        // The buffer length is exactly `proof_length(dimension)`, so unpacking cannot fail.
        let unpacked = unpack_proof_mut(&mut proof, dimension)
            .expect("proof buffer was allocated with proof_length(dimension) elements");
        init_function(unpacked.data);
        construct_proof(
            unpacked.data,
            dimension,
            unpacked.f0,
            unpacked.g0,
            unpacked.h0,
            unpacked.points_h_packed,
            self,
        );
        proof
    }
}
/// Returns the number of field elements in a Prio2 proof for an input of length `dimension`.
///
/// A proof consists of:
/// - the `dimension` data elements;
/// - the three scalars `f0`, `g0` and `h0`;
/// - `n = (dimension + 1).next_power_of_two()` packed `h` points.
pub(crate) fn proof_length(dimension: usize) -> usize {
    const SCALAR_COUNT: usize = 3; // f0, g0, h0
    let packed_h_len = (dimension + 1).next_power_of_two();
    dimension + SCALAR_COUNT + packed_h_len
}
/// Read-only view of a serialized Prio2 proof, produced by `unpack_proof`.
///
/// The underlying slice is laid out in this order:
/// - `data` (`dimension` elements);
/// - `f0`, `g0` and `h0`;
/// - `points_h_packed` (all remaining elements).
#[derive(Debug)]
pub(crate) struct UnpackedProof<'a, F: NttFriendlyFieldElement> {
    /// The first `dimension` elements: the input data.
    pub data: &'a [F],
    /// Constant point of `f` (drawn from the PRNG by `construct_proof`).
    pub f0: &'a F,
    /// Constant point of `g` (drawn from the PRNG by `construct_proof`).
    pub g0: &'a F,
    /// `f0 * g0`, as written by `construct_proof`.
    pub h0: &'a F,
    /// Products of the `f` and `g` evaluations at odd indices, as written by `construct_proof`.
    pub points_h_packed: &'a [F],
}
/// Mutable view of a Prio2 proof buffer, produced by `unpack_proof_mut`.
///
/// The underlying slice is laid out in this order:
/// - `data` (`dimension` elements);
/// - `f0`, `g0` and `h0`;
/// - `points_h_packed` (all remaining elements).
///
/// Writing through the fields updates the original buffer in place.
#[derive(Debug)]
pub(crate) struct UnpackedProofMut<'a, F: NttFriendlyFieldElement> {
    /// The first `dimension` elements: the input data.
    pub data: &'a mut [F],
    /// Slot for the constant point of `f`.
    pub f0: &'a mut F,
    /// Slot for the constant point of `g`.
    pub g0: &'a mut F,
    /// Slot for `f0 * g0`.
    pub h0: &'a mut F,
    /// Slots for the packed `h` points (odd-index products written by `construct_proof`).
    pub points_h_packed: &'a mut [F],
}
/// Splits `proof` into its components without copying.
///
/// The layout is `data` (`dimension` elements), then the scalars `f0`, `g0` and `h0`, then
/// `points_h_packed` (every remaining element).
///
/// # Errors
///
/// Returns [`SerializeError::UnpackInputSizeMismatch`] if `proof.len()` is not
/// `proof_length(dimension)`.
pub(crate) fn unpack_proof<F: NttFriendlyFieldElement>(
    proof: &[F],
    dimension: usize,
) -> Result<UnpackedProof<F>, SerializeError> {
    let expected_len = proof_length(dimension);
    if proof.len() == expected_len {
        let (data, tail) = proof.split_at(dimension);
        match tail.split_at(3) {
            ([f0, g0, h0], points_h_packed) => Ok(UnpackedProof {
                data,
                f0,
                g0,
                h0,
                points_h_packed,
            }),
            // The length check above guarantees exactly three scalars, so this arm is only a
            // panic-free fallback.
            _ => Err(SerializeError::UnpackInputSizeMismatch),
        }
    } else {
        Err(SerializeError::UnpackInputSizeMismatch)
    }
}
/// Splits `proof` into mutable, non-overlapping views of its components.
///
/// The layout is `data` (`dimension` elements), then the scalars `f0`, `g0` and `h0`, then
/// `points_h_packed` (every remaining element).
///
/// # Errors
///
/// Returns [`SerializeError::UnpackInputSizeMismatch`] if `proof.len()` is not
/// `proof_length(dimension)`.
pub(crate) fn unpack_proof_mut<F: NttFriendlyFieldElement>(
    proof: &mut [F],
    dimension: usize,
) -> Result<UnpackedProofMut<F>, SerializeError> {
    let expected_len = proof_length(dimension);
    if proof.len() == expected_len {
        let (data, tail) = proof.split_at_mut(dimension);
        match tail.split_at_mut(3) {
            ([f0, g0, h0], points_h_packed) => Ok(UnpackedProofMut {
                data,
                f0,
                g0,
                h0,
                points_h_packed,
            }),
            // The length check above guarantees exactly three scalars, so this arm is only a
            // panic-free fallback.
            _ => Err(SerializeError::UnpackInputSizeMismatch),
        }
    } else {
        Err(SerializeError::UnpackInputSizeMismatch)
    }
}
/// Turns the `n` point values in `points_in` into `2n` outputs in `evals_out`, using `coeffs`
/// as the intermediate buffer.
///
/// It calls `poly_ntt` twice:
/// 1. On `points_in` with `roots_n_inverted`, size `n` and flag `true`, writing into `coeffs`.
/// 2. On `coeffs` with `roots_2n`, size `2 * n` and flag `false`, writing into `evals_out`.
///
/// NOTE(review): presumably the flag selects the inverse transform. If so, the first call
/// interpolates coefficients and the second evaluates them on a domain twice as large; confirm
/// against `poly_ntt`.
///
/// The second call reads `2 * n` coefficients. This looks like it relies on `coeffs[n..]`
/// staying zero, i.e. on the first call writing only `n` outputs. TODO confirm.
fn interpolate_and_evaluate_at_2n<F: NttFriendlyFieldElement>(
    n: usize,
    points_in: &[F],
    evals_out: &mut [F],
    roots_n_inverted: &[F],
    roots_2n: &[F],
    ntt_memory: &mut PolyNttTempMemory<F>,
    coeffs: &mut [F],
) {
    poly_ntt(coeffs, points_in, roots_n_inverted, n, true, ntt_memory);
    poly_ntt(evals_out, coeffs, roots_2n, 2 * n, false, ntt_memory);
}
/// Fills in the proof components `f0`, `g0`, `h0` and `points_h_packed` for `data`.
///
/// Steps:
/// 1. Draw `f0` then `g0` from `mem.prng` and set `h0 = f0 * g0`.
/// 2. Fill the point buffers:
///    - `mem.points_f` becomes `[f0, data[0], .., data[dimension - 1], ..]`.
///    - `mem.points_g` becomes `[g0, data[0] - 1, .., data[dimension - 1] - 1, ..]`.
///    - Entries past `dimension` are left untouched (zero for a freshly built `ClientMemory`).
/// 3. Pass both buffers through `interpolate_and_evaluate_at_2n`.
/// 4. Store the products of the two results at odd indices `1, 3, .., 2n - 1` in
///    `points_h_packed`.
fn construct_proof<F: NttFriendlyFieldElement>(
    data: &[F],
    dimension: usize,
    f0: &mut F,
    g0: &mut F,
    h0: &mut F,
    points_h_packed: &mut [F],
    mem: &mut ClientMemory<F>,
) {
    let n = (dimension + 1).next_power_of_two();

    // Random constant points; the PRNG draw order (f first, then g) is significant.
    let rand_f = mem.prng.get();
    *f0 = rand_f;
    let rand_g = mem.prng.get();
    *g0 = rand_g;
    mem.points_f[0] = rand_f;
    mem.points_g[0] = rand_g;
    *h0 = rand_f * rand_g;

    // f takes each data value as-is; g takes each data value minus one.
    let f_points = &mut mem.points_f[1..=dimension];
    let g_points = &mut mem.points_g[1..=dimension];
    let values = &data[..dimension];
    for (idx, &value) in values.iter().enumerate() {
        f_points[idx] = value;
        g_points[idx] = value - F::one();
    }

    interpolate_and_evaluate_at_2n(
        n,
        &mem.points_f,
        &mut mem.evals_f,
        &mem.roots_n_inverted,
        &mem.roots_2n,
        &mut mem.ntt_memory,
        &mut mem.coeffs,
    );
    interpolate_and_evaluate_at_2n(
        n,
        &mem.points_g,
        &mut mem.evals_g,
        &mem.roots_n_inverted,
        &mem.roots_2n,
        &mut mem.ntt_memory,
        &mut mem.coeffs,
    );

    // Only the odd-indexed products are kept: 1, 3, ..., 2n - 1.
    for (slot, odd) in (1..2 * n).step_by(2).enumerate() {
        points_h_packed[slot] = mem.evals_f[odd] * mem.evals_g[odd];
    }
}
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;

    use crate::{
        field::{Field64, FieldPrio2},
        vdaf::prio2::client::{proof_length, unpack_proof, unpack_proof_mut, SerializeError},
    };

    /// Checks that the mutable view maps every component onto the expected buffer offsets and
    /// that both too-short and too-long buffers are rejected.
    #[test]
    fn test_unpack_share_mut() {
        let dim = 15;
        let len = proof_length(dim);
        let mut share = vec![FieldPrio2::from(0); len];
        let unpacked = unpack_proof_mut(&mut share, dim).unwrap();
        // Data section first, three scalars, then the packed h points fill the remainder.
        assert_eq!(unpacked.data.len(), dim);
        assert_eq!(unpacked.points_h_packed.len(), len - dim - 3);
        *unpacked.f0 = FieldPrio2::from(12);
        *unpacked.g0 = FieldPrio2::from(13);
        *unpacked.h0 = FieldPrio2::from(14);
        assert_eq!(share[dim], 12);
        assert_eq!(share[dim + 1], 13);
        assert_eq!(share[dim + 2], 14);

        let mut short_share = vec![FieldPrio2::from(0); len - 1];
        assert_matches!(
            unpack_proof_mut(&mut short_share, dim),
            Err(SerializeError::UnpackInputSizeMismatch)
        );
        let mut long_share = vec![FieldPrio2::from(0); len + 1];
        assert_matches!(
            unpack_proof_mut(&mut long_share, dim),
            Err(SerializeError::UnpackInputSizeMismatch)
        );
    }

    /// Checks that the read-only view borrows the expected elements and that both too-short and
    /// too-long buffers are rejected.
    #[test]
    fn test_unpack_share() {
        let dim = 15;
        let len = proof_length(dim);
        let share = vec![Field64::from(0); len];
        let unpacked = unpack_proof(&share, dim).unwrap();
        assert_eq!(unpacked.data.len(), dim);
        assert_eq!(unpacked.points_h_packed.len(), len - dim - 3);
        // Pointer identity pins each scalar to its exact offset in the buffer.
        assert!(std::ptr::eq(unpacked.f0, &share[dim]));
        assert!(std::ptr::eq(unpacked.g0, &share[dim + 1]));
        assert!(std::ptr::eq(unpacked.h0, &share[dim + 2]));

        let short_share = vec![Field64::from(0); len - 1];
        assert_matches!(
            unpack_proof(&short_share, dim),
            Err(SerializeError::UnpackInputSizeMismatch)
        );
        let long_share = vec![Field64::from(0); len + 1];
        assert_matches!(
            unpack_proof(&long_share, dim),
            Err(SerializeError::UnpackInputSizeMismatch)
        );
    }
}