use crate::bfv::{BfvParameters, Ciphertext, Plaintext};
use crate::proto::bfv::SecretKey as SecretKeyProto;
use crate::{Error, Result, SerializationError};
use fhe_math::{
rq::{traits::TryConvertFrom, Poly, Representation},
zq::Modulus,
};
use fhe_traits::{DeserializeParametrized, FheDecrypter, FheEncrypter, FheParametrized, Serialize};
use fhe_util::sample_vec_cbd;
use itertools::Itertools;
use num_bigint::BigUint;
use prost::Message;
use rand::{CryptoRng, Rng, RngCore, SeedableRng};
use rand_chacha::ChaCha8Rng;
use std::sync::Arc;
use zeroize::{Zeroize, Zeroizing};
/// Secret key of the BFV encryption scheme.
///
/// The coefficients are wiped when the key is dropped (see the `Zeroize` and
/// `Drop` implementations of this type).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SecretKey {
    /// Parameters this key was created for.
    pub(crate) par: Arc<BfvParameters>,
    /// Signed coefficients of the secret polynomial, one per polynomial degree.
    pub(crate) coeffs: Box<[i64]>,
}
impl Zeroize for SecretKey {
fn zeroize(&mut self) {
self.coeffs.zeroize();
}
}
impl Drop for SecretKey {
fn drop(&mut self) {
self.zeroize();
}
}
impl SecretKey {
pub fn random<R: RngCore + CryptoRng>(par: &Arc<BfvParameters>, rng: &mut R) -> Self {
let s_coefficients = sample_vec_cbd(par.degree(), par.variance, rng).unwrap();
Self::new(s_coefficients, par)
}
pub(crate) fn new(coeffs: Vec<i64>, par: &Arc<BfvParameters>) -> Self {
Self {
par: par.to_owned(),
coeffs: coeffs.into_boxed_slice(),
}
}
pub unsafe fn measure_noise(&self, ct: &Ciphertext) -> Result<usize> {
let plaintext = Zeroizing::new(self.try_decrypt(ct)?);
let m = Zeroizing::new(plaintext.to_poly());
let mut s = Zeroizing::new(Poly::try_convert_from(
self.coeffs.as_ref(),
ct[0].ctx(),
false,
Representation::PowerBasis,
)?);
s.change_representation(Representation::Ntt);
let mut si = s.clone();
let mut c = Zeroizing::new(ct[0].clone());
c.disallow_variable_time_computations();
for i in 1..ct.len() {
let mut cis = Zeroizing::new(ct[i].clone());
cis.disallow_variable_time_computations();
*cis.as_mut() *= si.as_ref();
*c.as_mut() += &cis;
*si.as_mut() *= s.as_ref();
}
*c.as_mut() -= &m;
c.change_representation(Representation::PowerBasis);
let ciphertext_modulus = ct[0].ctx().modulus();
let mut noise = 0usize;
for coeff in Vec::<BigUint>::from(c.as_ref()) {
noise = std::cmp::max(
noise,
std::cmp::min(coeff.bits(), (ciphertext_modulus - &coeff).bits()) as usize,
)
}
Ok(noise)
}
pub(crate) fn encrypt_poly<R: RngCore + CryptoRng>(
&self,
p: &Poly,
rng: &mut R,
) -> Result<Ciphertext> {
assert_eq!(p.representation(), &Representation::Ntt);
let level = self.par.level_of_context(p.ctx())?;
let mut seed = <ChaCha8Rng as SeedableRng>::Seed::default();
rand::rng().fill(&mut seed);
let mut s = Zeroizing::new(Poly::try_convert_from(
self.coeffs.as_ref(),
p.ctx(),
false,
Representation::PowerBasis,
)?);
s.change_representation(Representation::Ntt);
let mut a = Poly::random_from_seed(p.ctx(), Representation::Ntt, seed);
let a_s = Zeroizing::new(&a * s.as_ref());
let mut b = Poly::small(p.ctx(), Representation::Ntt, self.par.variance, rng)
.map_err(Error::MathError)?;
b -= &a_s;
b += p;
unsafe {
a.allow_variable_time_computations();
b.allow_variable_time_computations()
}
Ok(Ciphertext {
par: self.par.clone(),
seed: Some(seed),
c: vec![b, a],
level,
})
}
}
impl From<&SecretKey> for SecretKeyProto {
    /// Copies the key's coefficients into a protobuf message.
    fn from(sk: &SecretKey) -> Self {
        let coeffs = sk.coeffs.iter().copied().collect();
        Self { coeffs }
    }
}
impl Serialize for SecretKey {
fn to_bytes(&self) -> Vec<u8> {
SecretKeyProto::from(self).encode_to_vec()
}
}
impl DeserializeParametrized for SecretKey {
type Error = Error;
fn from_bytes(bytes: &[u8], par: &Arc<Self::Parameters>) -> Result<Self> {
let proto: SecretKeyProto = Message::decode(bytes).map_err(|_| {
Error::SerializationError(SerializationError::ProtobufError {
message: "SecretKey decode".into(),
})
})?;
if proto.coeffs.len() != par.degree() {
return Err(Error::SerializationError(
SerializationError::InvalidFormat {
reason: "SecretKey coeffs length and parameters degree mismatch".into(),
},
));
}
Ok(Self {
par: par.clone(),
coeffs: proto.coeffs.into_boxed_slice(),
})
}
}
// A secret key is parametrized by a set of BFV parameters.
impl FheParametrized for SecretKey {
/// Type of the parameters the key is bound to.
type Parameters = BfvParameters;
}
impl FheEncrypter<Plaintext, Ciphertext> for SecretKey {
type Error = Error;
fn try_encrypt<R: RngCore + CryptoRng>(
&self,
pt: &Plaintext,
rng: &mut R,
) -> Result<Ciphertext> {
assert!(Arc::ptr_eq(&self.par, &pt.par));
let m = Zeroizing::new(pt.to_poly());
self.encrypt_poly(m.as_ref(), rng)
}
}
impl FheDecrypter<Plaintext, Ciphertext> for SecretKey {
type Error = Error;
fn try_decrypt(&self, ct: &Ciphertext) -> Result<Plaintext> {
if !Arc::ptr_eq(&self.par, &ct.par) {
Err(Error::DefaultError(
"Incompatible BFV parameters".to_string(),
))
} else {
let mut s = Zeroizing::new(Poly::try_convert_from(
self.coeffs.as_ref(),
ct[0].ctx(),
false,
Representation::PowerBasis,
)?);
s.change_representation(Representation::Ntt);
let mut si = s.clone();
let mut c = Zeroizing::new(ct[0].clone());
c.disallow_variable_time_computations();
for i in 1..ct.len() {
let mut cis = Zeroizing::new(ct[i].clone());
cis.disallow_variable_time_computations();
*cis.as_mut() *= si.as_ref();
*c.as_mut() += &cis;
if i + 1 < ct.len() {
*si.as_mut() *= s.as_ref();
}
}
c.change_representation(Representation::PowerBasis);
let ctx_lvl = self.par.context_level_at(ct.level).unwrap();
let d = Zeroizing::new(c.scale(&ctx_lvl.cipher_plain_context.scaler)?);
let v = Zeroizing::new(
Vec::<u64>::from(d.as_ref())
.iter_mut()
.map(|vi| *vi + *self.par.plaintext)
.collect_vec(),
);
let mut w = v[..self.par.degree()].to_vec();
let q = Modulus::new(self.par.moduli[0]).map_err(Error::MathError)?;
q.reduce_vec(&mut w);
self.par.plaintext.reduce_vec(&mut w);
let mut poly =
Poly::try_convert_from(&w, ct[0].ctx(), false, Representation::PowerBasis)?;
poly.change_representation(Representation::Ntt);
let pt = Plaintext {
par: self.par.clone(),
value: w.into_boxed_slice(),
encoding: None,
poly_ntt: poly,
level: ct.level,
};
Ok(pt)
}
}
}
#[cfg(test)]
mod tests {
    use super::SecretKey;
    use crate::bfv::{parameters::BfvParameters, Encoding, Plaintext};
    use crate::proto::bfv::SecretKey as SecretKeyProto;
    use fhe_traits::{DeserializeParametrized, FheDecrypter, FheEncoder, FheEncrypter, Serialize};
    use prost::Message;
    use rand::rng;
    use std::error::Error;

    /// A freshly generated key keeps its parameters and has coefficients
    /// bounded by twice the variance.
    #[test]
    fn keygen() {
        let mut rng = rng();
        let params = BfvParameters::default_arc(1, 16);
        let sk = SecretKey::random(&params, &mut rng);
        assert_eq!(sk.par, params);

        sk.coeffs.iter().for_each(|ci| {
            assert!((*ci).abs() <= 2 * sk.par.variance as i64)
        })
    }

    /// Decrypting an encryption returns the original plaintext, for two
    /// parameter sets and every level in `0..max_level()`.
    #[test]
    fn encrypt_decrypt() -> Result<(), Box<dyn Error>> {
        let mut rng = rng();
        for params in [
            BfvParameters::default_arc(1, 16),
            BfvParameters::default_arc(6, 16),
        ] {
            for level in 0..params.max_level() {
                for _ in 0..20 {
                    let sk = SecretKey::random(&params, &mut rng);

                    let pt = Plaintext::try_encode(
                        &params.plaintext.random_vec(params.degree(), &mut rng),
                        Encoding::poly_at_level(level),
                        &params,
                    )?;
                    let ct = sk.try_encrypt(&pt, &mut rng)?;
                    let pt2 = sk.try_decrypt(&ct)?;

                    println!("Noise: {}", unsafe { sk.measure_noise(&ct)? });
                    assert_eq!(pt2, pt);
                }
            }
        }

        Ok(())
    }

    /// Serializing a key and deserializing the bytes yields an equal key.
    #[test]
    fn serialize_roundtrip() -> Result<(), Box<dyn Error>> {
        let mut rng = rng();
        let params = BfvParameters::default_arc(2, 16);
        let sk = SecretKey::random(&params, &mut rng);
        let bytes = sk.to_bytes();
        let decoded = SecretKey::from_bytes(&bytes, &params)?;
        assert_eq!(decoded, sk);
        Ok(())
    }

    /// A message with one coefficient too few is rejected as `InvalidFormat`.
    #[test]
    fn deserialize_invalid_length() {
        let params = BfvParameters::default_arc(1, 16);
        let mut proto = SecretKeyProto {
            coeffs: vec![0; params.degree()],
        };
        // Drop one coefficient so the length no longer matches the degree.
        proto.coeffs.pop();
        let bytes = proto.encode_to_vec();
        let err = SecretKey::from_bytes(&bytes, &params).unwrap_err();
        assert!(matches!(
            err,
            crate::Error::SerializationError(crate::SerializationError::InvalidFormat { .. })
        ));
    }
}