use std::sync::Arc;
use fhe_math::rq::traits::TryConvertFrom;
use fhe_math::rq::{Poly, Representation};
use rand::{CryptoRng, RngCore};
use zeroize::Zeroizing;
use crate::bfv::{BfvParameters, Ciphertext, PublicKey, SecretKey};
use crate::{Error, Result};
use super::Aggregate;
pub struct PublicKeySwitchShare {
pub(crate) par: Arc<BfvParameters>,
pub(crate) c0: Poly,
pub(crate) h0_share: Poly,
pub(crate) h1_share: Poly,
}
impl PublicKeySwitchShare {
pub fn new<R: RngCore + CryptoRng>(
sk_share: &SecretKey,
public_key: &PublicKey,
ct: &Ciphertext,
rng: &mut R,
) -> Result<Self> {
if sk_share.par != public_key.par || public_key.par != ct.par {
return Err(Error::DefaultError(
"Incompatible BFV parameters".to_string(),
));
}
let par = sk_share.par.clone();
let mut pk_ct = public_key.c.clone();
while pk_ct.level != ct.level {
pk_ct.switch_down()?;
}
let ctx = par.context_at_level(ct.level)?;
let mut s = Zeroizing::new(Poly::try_convert_from(
sk_share.coeffs.as_ref(),
ctx,
false,
Representation::PowerBasis,
)?);
s.change_representation(Representation::Ntt);
s.disallow_variable_time_computations();
let u = Zeroizing::new(Poly::small(ctx, Representation::Ntt, par.variance, rng)?);
let e0 = Zeroizing::new(Poly::small(ctx, Representation::Ntt, par.variance, rng)?);
let e1 = Zeroizing::new(Poly::small(ctx, Representation::Ntt, par.variance, rng)?);
let mut h0 = pk_ct[0].clone();
h0.disallow_variable_time_computations();
h0 *= u.as_ref();
*s.as_mut() *= &ct[1];
h0 += s.as_ref();
h0 += e0.as_ref();
let mut h1 = pk_ct[1].clone();
h1.disallow_variable_time_computations();
h1 *= u.as_ref();
h1 += e1.as_ref();
unsafe {
h0.allow_variable_time_computations();
h1.allow_variable_time_computations();
}
Ok(Self {
par,
c0: ct[0].clone(),
h0_share: h0,
h1_share: h1,
})
}
}
impl Aggregate<PublicKeySwitchShare> for Ciphertext {
fn from_shares<T>(iter: T) -> Result<Self>
where
T: IntoIterator<Item = PublicKeySwitchShare>,
{
let mut shares = iter.into_iter();
let share = shares.next().ok_or(Error::TooFewValues {
actual: 0,
minimum: 1,
})?;
let mut h0 = share.h0_share;
let mut h1 = share.h1_share;
for sh in shares {
h0 += &sh.h0_share;
h1 += &sh.h1_share;
}
let c0 = &share.c0 + &h0;
Ciphertext::new(vec![c0, h1], &share.par)
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use fhe_traits::{FheDecrypter, FheEncoder, FheEncrypter};
use rand::rng;
use crate::{
bfv::{BfvParameters, Encoding, Plaintext, PublicKey, SecretKey},
mbfv::{AggregateIter, CommonRandomPoly, PublicKeyShare, PublicKeySwitchShare},
};
const NUM_PARTIES: usize = 11;
struct Party {
sk_share: SecretKey,
pk_share: PublicKeyShare,
}
#[test]
fn encrypt_keyswitch_decrypt() {
let mut rng = rng();
for par in [
BfvParameters::default_arc(1, 16),
BfvParameters::default_arc(6, 32),
] {
for level in 0..=par.max_level() {
for _ in 0..20 {
let crp = CommonRandomPoly::new(&par, &mut rng).unwrap();
let mut parties: Vec<Party> = vec![];
for _ in 0..NUM_PARTIES {
let sk_share = SecretKey::random(&par, &mut rng);
let pk_share =
PublicKeyShare::new(&sk_share, crp.clone(), &mut rng).unwrap();
parties.push(Party { sk_share, pk_share })
}
let public_key: PublicKey = parties
.iter()
.map(|p| p.pk_share.clone())
.aggregate()
.unwrap();
let pt1 = Plaintext::try_encode(
&par.plaintext.random_vec(par.degree(), &mut rng),
Encoding::poly_at_level(level),
&par,
)
.unwrap();
let ct1 = Arc::new(public_key.try_encrypt(&pt1, &mut rng).unwrap());
let sk_out = SecretKey::random(&par, &mut rng);
let pk_out = PublicKey::new(&sk_out, &mut rng);
let ct2 = parties
.iter()
.map(|p| PublicKeySwitchShare::new(&p.sk_share, &pk_out, &ct1, &mut rng))
.aggregate()
.unwrap();
let pt2 = sk_out.try_decrypt(&ct2).unwrap();
assert_eq!(pt1, pt2);
}
}
}
}
}