use crate::{
encrypt::{decrypt_share, EncryptError, PrivateKey},
field::{merge_vector, FieldElement, FieldError},
polynomial::{poly_interpret_eval, PolyAuxMemory},
prng::extract_share_from_seed,
util::{deserialize, proof_length, unpack_proof, vector_with_length, SerializeError},
};
#[derive(Debug, thiserror::Error)]
pub enum ServerError {
#[error("encryption/decryption error")]
Encrypt(#[from] EncryptError),
#[error("finite field operation error")]
Field(#[from] FieldError),
#[error("serialization/deserialization error")]
Serialize(#[from] SerializeError),
}
/// Scratch buffers reused by [`generate_verification_message`].
///
/// With `n` the smallest power of two that is at least `dimension + 1`,
/// `points_f` and `points_g` hold `n` values and `points_h` holds `2 * n`.
#[derive(Debug)]
pub struct ValidationMemory<F: FieldElement> {
    /// Points of polynomial f: index 0 is the constant-term slot, then the data.
    points_f: Vec<F>,
    /// Points of polynomial g, laid out like `points_f`.
    points_g: Vec<F>,
    /// Points of polynomial h (twice as many as f and g).
    points_h: Vec<F>,
    /// Auxiliary memory (roots, coefficient and FFT buffers) for polynomial evaluation.
    poly_mem: PolyAuxMemory<F>,
}

impl<F: FieldElement> ValidationMemory<F> {
    /// Allocates buffers large enough to validate vectors of `dimension` elements.
    pub fn new(dimension: usize) -> Self {
        // One slot for the constant term plus one per data element, rounded up
        // to a power of two (presumably required by the FFT-based helpers).
        let half_len = (dimension + 1).next_power_of_two();
        let full_len = half_len * 2;

        Self {
            points_f: vector_with_length(half_len),
            points_g: vector_with_length(half_len),
            points_h: vector_with_length(full_len),
            poly_mem: PolyAuxMemory::new(half_len),
        }
    }
}
/// A validation/aggregation server.
///
/// It decrypts client shares, produces verification messages for them, and
/// merges the data part of every share that passes verification into its
/// accumulator.
#[derive(Debug)]
pub struct Server<F: FieldElement> {
    /// Number of data elements in each client vector.
    dimension: usize,
    /// Selects how shares are decoded and whether this server subtracts one
    /// when building the points of polynomial g.
    is_first_server: bool,
    /// `dimension` elements into which accepted data shares are merged.
    accumulator: Vec<F>,
    /// Scratch buffers reused for every verification message.
    validation_mem: ValidationMemory<F>,
    /// Key used to decrypt incoming shares.
    private_key: PrivateKey,
}
impl<F: FieldElement> Server<F> {
pub fn new(dimension: usize, is_first_server: bool, private_key: PrivateKey) -> Server<F> {
Server {
dimension,
is_first_server,
accumulator: vector_with_length(dimension),
validation_mem: ValidationMemory::new(dimension),
private_key,
}
}
fn deserialize_share(&self, encrypted_share: &[u8]) -> Result<Vec<F>, ServerError> {
let share = decrypt_share(encrypted_share, &self.private_key)?;
Ok(if self.is_first_server {
deserialize(&share)?
} else {
let len = proof_length(self.dimension);
extract_share_from_seed(len, &share)
})
}
pub fn generate_verification_message(
&mut self,
eval_at: F,
share: &[u8],
) -> Result<VerificationMessage<F>, ServerError> {
let share_field = self.deserialize_share(share)?;
generate_verification_message(
self.dimension,
eval_at,
&share_field,
self.is_first_server,
&mut self.validation_mem,
)
}
pub fn aggregate(
&mut self,
share: &[u8],
v1: &VerificationMessage<F>,
v2: &VerificationMessage<F>,
) -> Result<bool, ServerError> {
let share_field = self.deserialize_share(share)?;
let is_valid = is_valid_share(v1, v2);
if is_valid {
merge_vector(&mut self.accumulator, &share_field[..self.dimension])?;
}
Ok(is_valid)
}
pub fn total_shares(&self) -> &[F] {
&self.accumulator
}
pub fn merge_total_shares(&mut self, other_total_shares: &[F]) -> Result<(), ServerError> {
Ok(merge_vector(&mut self.accumulator, other_total_shares)?)
}
pub fn choose_eval_at(&self) -> F {
loop {
let eval_at = F::rand();
if !self.validation_mem.poly_mem.roots_2n.contains(&eval_at) {
break eval_at;
}
}
}
}
/// One server's verification message.
///
/// It holds the evaluations, at the chosen point `r`, of the polynomials f, g
/// and h that this server built from its share. Combine the two servers'
/// messages with [`is_valid_share`].
///
/// The derives apply only when `F` itself implements the respective trait.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct VerificationMessage<F: FieldElement> {
    /// This server's evaluation of f at `r`.
    pub f_r: F,
    /// This server's evaluation of g at `r`.
    pub g_r: F,
    /// This server's evaluation of h at `r`.
    pub h_r: F,
}
/// Computes one server's verification message for its share of a proof.
///
/// `proof` is this server's share, in the layout `unpack_proof` expects for
/// `dimension` data elements. The point buffers in `mem` are filled from it,
/// and `poly_interpret_eval` (presumably interpolate-then-evaluate) is applied
/// to the points of f, g and h at `eval_at`.
///
/// `mem` is overwritten scratch space and is assumed to come from
/// `ValidationMemory::new(dimension)`. Smaller buffers panic on indexing.
///
/// Only the first server (`is_first_server == true`) subtracts one from each
/// data element when building g. If the two servers' shares are summed, the
/// g points therefore become `x - 1`.
///
/// # Errors
/// Returns an error if `proof` cannot be unpacked for `dimension`.
pub fn generate_verification_message<F: FieldElement>(
    dimension: usize,
    eval_at: F,
    proof: &[F],
    is_first_server: bool,
    mem: &mut ValidationMemory<F>,
) -> Result<VerificationMessage<F>, ServerError> {
    let unpacked = unpack_proof(proof, dimension)?;
    // Number of points defining h: twice the padded length used for f and g.
    let h_len = 2 * (dimension + 1).next_power_of_two();

    // Index 0 of each buffer takes the constant-term values carried in the proof.
    mem.points_f[0] = *unpacked.f0;
    mem.points_g[0] = *unpacked.g0;
    mem.points_h[0] = *unpacked.h0;

    // Data elements occupy indices 1..=data.len() of f and g.
    for (slot, &x) in (1..).zip(unpacked.data.iter()) {
        mem.points_f[slot] = x;
        mem.points_g[slot] = if is_first_server { x - F::one() } else { x };
    }

    // Packed h values fill the odd indices 1, 3, 5, ... below `h_len`. Even
    // indices other than 0 are never written here, so they keep their initial
    // values from `vector_with_length` (presumably zero).
    for (packed_idx, h_idx) in (1..h_len).step_by(2).enumerate() {
        mem.points_h[h_idx] = unpacked.points_h_packed[packed_idx];
    }

    // f and g use the n-point inverted roots; h uses the 2n-point ones. All
    // three share the coefficient and FFT scratch buffers.
    let poly = &mut mem.poly_mem;
    let f_r = poly_interpret_eval(
        &mem.points_f,
        &poly.roots_n_inverted,
        eval_at,
        &mut poly.coeffs,
        &mut poly.fft_memory,
    );
    let g_r = poly_interpret_eval(
        &mem.points_g,
        &poly.roots_n_inverted,
        eval_at,
        &mut poly.coeffs,
        &mut poly.fft_memory,
    );
    let h_r = poly_interpret_eval(
        &mem.points_h,
        &poly.roots_2n_inverted,
        eval_at,
        &mut poly.coeffs,
        &mut poly.fft_memory,
    );

    Ok(VerificationMessage { f_r, g_r, h_r })
}
/// Decides whether a client submission passes verification.
///
/// The two servers' messages are added component-wise. The submission is
/// accepted exactly when the combined values satisfy `f(r) * g(r) == h(r)`.
pub fn is_valid_share<F: FieldElement>(
    v1: &VerificationMessage<F>,
    v2: &VerificationMessage<F>,
) -> bool {
    (v1.f_r + v2.f_r) * (v1.g_r + v2.g_r) == v1.h_r + v2.h_r
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::field::Field32;
    use crate::util;

    /// Number of data elements in the sample proof.
    const DIM: usize = 8;

    /// A proof that the original test accepts: the eight data elements
    /// `[1, 0, 0, 0, 0, 0, 0, 0]` followed by the proof encoding.
    const VALID_PROOF: [u32; 27] = [
        1, 0, 0, 0, 0, 0, 0, 0, 2052337230, 3217065186, 1886032198, 2533724497, 397524722,
        3820138372, 1535223968, 4291254640, 3565670552, 2447741959, 163741941, 335831680,
        2567182742, 3542857140, 124017604, 4201373647, 431621210, 1618555683, 267689149,
    ];

    /// Secret-shares `proof_u32`, lets both servers build verification messages
    /// at a fixed point, and returns the combined verdict.
    fn shares_verify(proof_u32: &[u32]) -> bool {
        let mut proof: Vec<Field32> = proof_u32.iter().map(|x| Field32::from(*x)).collect();
        // After this call `proof` serves as the first server's share and
        // `share2` as the second's.
        let share2 = util::tests::secret_share(&mut proof);
        let eval_at = Field32::from(12313);
        let mut validation_mem = ValidationMemory::new(DIM);
        let v1 = generate_verification_message(DIM, eval_at, &proof, true, &mut validation_mem)
            .unwrap();
        let v2 = generate_verification_message(DIM, eval_at, &share2, false, &mut validation_mem)
            .unwrap();
        is_valid_share(&v1, &v2)
    }

    #[test]
    fn test_validation() {
        assert!(shares_verify(&VALID_PROOF));
    }

    #[test]
    fn test_validation_rejects_tampered_data() {
        // Changing a data element without regenerating the proof makes the
        // combined f * g disagree with h, so verification must fail.
        let mut tampered = VALID_PROOF;
        tampered[0] = 2;
        assert!(!shares_verify(&tampered));
    }
}