use std::sync::Arc;
use crate::bfv::{BfvParameters, Ciphertext, PublicKey, SecretKey};
use crate::errors::Result;
use crate::Error;
use fhe_math::rq::{traits::TryConvertFrom, Poly, Representation};
use rand::{CryptoRng, RngCore};
use zeroize::Zeroizing;
use super::{Aggregate, CommonRandomPoly};
/// One party's contribution to a public key that several parties generate jointly.
///
/// Shares built from the same [`CommonRandomPoly`] are combined into a single
/// [`PublicKey`] through [`Aggregate::from_shares`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PublicKeyShare {
    /// BFV parameters, taken from the secret-key share the share was derived from.
    pub(crate) par: Arc<BfvParameters>,
    /// Common random polynomial; expected to be the same for every participating party.
    pub(crate) crp: CommonRandomPoly,
    /// This party's share of the first public-key component (`-crp * s + e`, NTT form).
    pub(crate) p0_share: Poly,
}
impl PublicKeyShare {
    /// Generates this party's share of a collectively generated public key.
    ///
    /// The share is `p0_share = -crp * s + e`, computed in NTT representation. Here `s` is
    /// the party's secret-key share and `e` is a freshly sampled small error polynomial.
    ///
    /// # Arguments
    ///
    /// * `sk_share` - this party's secret-key share; its BFV parameters are reused for the share.
    /// * `crp` - the common random polynomial; it is moved into the returned share.
    /// * `rng` - cryptographically secure RNG used to sample the error polynomial.
    ///
    /// # Errors
    ///
    /// Propagates any error from obtaining the level-0 context, converting the secret-key
    /// coefficients into a polynomial, or sampling the error polynomial.
    pub fn new<R: RngCore + CryptoRng>(
        sk_share: &SecretKey,
        crp: CommonRandomPoly,
        rng: &mut R,
    ) -> Result<Self> {
        let par = sk_share.par.clone();
        let ctx = par.context_at_level(0)?;
        // Secret-dependent polynomials are wrapped in `Zeroizing` so their memory is wiped on drop.
        let mut s = Zeroizing::new(Poly::try_convert_from(
            sk_share.coeffs.as_ref(),
            ctx,
            false,
            Representation::PowerBasis,
        )?);
        s.change_representation(Representation::Ntt);
        // Small error polynomial drawn with the parameters' variance, directly in NTT form.
        let e = Zeroizing::new(Poly::small(ctx, Representation::Ntt, par.variance, rng)?);
        let mut p0_share = -crp.poly.clone();
        // Disallow variable-time arithmetic before `p0_share` is multiplied by the secret `s`,
        // so the secret-dependent product is not computed with variable-time routines.
        p0_share.disallow_variable_time_computations();
        p0_share.change_representation(Representation::Ntt);
        p0_share *= s.as_ref();
        p0_share += e.as_ref();
        // SAFETY: at this point `p0_share` already includes the error `e`. It is the value this
        // party hands out as its share: it is stored in `Self` and later summed into the
        // aggregated `PublicKey`. Variable-time operations on it are therefore treated as
        // acceptable.
        // NOTE(review): this relies on the masking provided by `e`; confirm against the safety
        // contract of `Poly::allow_variable_time_computations`.
        unsafe { p0_share.allow_variable_time_computations() }
        Ok(Self { par, crp, p0_share })
    }
}
impl Aggregate<PublicKeyShare> for PublicKey {
fn from_shares<T>(iter: T) -> Result<Self>
where
T: IntoIterator<Item = PublicKeyShare>,
{
let mut shares = iter.into_iter();
let share = shares.next().ok_or(Error::TooFewValues {
actual: 0,
minimum: 1,
})?;
let mut p0 = share.p0_share;
for sh in shares {
p0 += &sh.p0_share;
}
Ok(PublicKey {
c: Ciphertext::new(vec![p0, share.crp.poly], &share.par)?,
par: share.par,
})
}
}
#[cfg(test)]
mod tests {
    use fhe_traits::{FheEncoder, FheEncrypter};
    use rand::rng;

    use crate::{
        bfv::{BfvParameters, Encoding, Plaintext, PublicKey, SecretKey},
        mbfv::{Aggregate as _, CommonRandomPoly},
    };

    use super::PublicKeyShare;

    /// Number of simulated parties that each contribute a public-key share.
    const NUM_PARTIES: usize = 11;

    /// Runs the joint key-generation protocol and checks that the aggregated key can
    /// encrypt a random plaintext at every level of each parameter set.
    #[test]
    fn protocol_creates_valid_pk() {
        let mut rng = rng();
        let parameter_sets = [
            BfvParameters::default_arc(1, 16),
            BfvParameters::default_arc(6, 32),
        ];
        for par in parameter_sets {
            for level in 0..=par.max_level() {
                for _trial in 0..20 {
                    let crp = CommonRandomPoly::new(&par, &mut rng).unwrap();
                    let pk_shares = (0..NUM_PARTIES)
                        .map(|_| {
                            let sk_share = SecretKey::random(&par, &mut rng);
                            PublicKeyShare::new(&sk_share, crp.clone(), &mut rng).unwrap()
                        })
                        .collect::<Vec<PublicKeyShare>>();
                    let public_key = PublicKey::from_shares(pk_shares).unwrap();

                    let values = par.plaintext.random_vec(par.degree(), &mut rng);
                    let pt =
                        Plaintext::try_encode(&values, Encoding::poly_at_level(level), &par)
                            .unwrap();
                    let _ct = public_key.try_encrypt(&pt, &mut rng).unwrap();
                }
            }
        }
    }
}