use num::Complex;
use serde::{Deserialize, Serialize};
use std::mem::size_of;
use sunscreen_tfhe::OverlaySize;
use sunscreen_tfhe::entities::{
AutomorphismKey, AutomorphismKeyFft, AutomorphismKeyFftRef, AutomorphismKeyRef, BootstrapKey,
BootstrapKeyFft, BootstrapKeyRef, GlweSecretKey, GlweSecretKeyRef, LweKeyswitchKey,
LweKeyswitchKeyRef, LweSecretKey, LweSecretKeyRef, RlwePublicKey, RlwePublicKeyRef,
SchemeSwitchKey, SchemeSwitchKeyFft, SchemeSwitchKeyRef,
};
pub use sunscreen_tfhe::high_level::keygen::Seed;
use sunscreen_tfhe::high_level::{fft, keygen};
use sunscreen_tfhe::ops::automorphisms::generate_automorphism_key;
use sunscreen_tfhe::ops::bootstrapping::generate_scheme_switch_key;
use sunscreen_tfhe::ops::encryption::rlwe_generate_public_key;
use crate::DEFAULT_128;
use crate::params::Params;
use crate::safe_bincode::GetSize;
/// Public key material derived from a [`SecretKey`].
///
/// Holds a single RLWE public key produced by `rlwe_generate_public_key` from
/// the secret key's level-1 GLWE component (see `PublicKey::generate`).
#[derive(Clone, Serialize, Deserialize)]
pub struct PublicKey {
    /// RLWE public key generated under `SecretKey::glwe_1` with `Params::l1_params`.
    pub rlwe_1: RlwePublicKey<u64>,
}
impl GetSize for PublicKey {
    /// Size in bytes that `crate::safe_bincode` associates with a [`PublicKey`]
    /// for `params`.
    ///
    /// Counts the `u64` words of the level-1 RLWE public key plus one extra
    /// `u64` word (presumably the length prefix of its buffer — TODO confirm).
    fn get_size(params: &Params) -> usize {
        let key_words = RlwePublicKeyRef::<u64>::size(params.l1_params.dim);
        let overhead_words = 1;
        (key_words + overhead_words) * size_of::<u64>()
    }

    /// Validates `rlwe_1` against `params.l1_params.dim`.
    fn check_is_valid(&self, params: &Params) -> crate::Result<()> {
        let Self { rlwe_1 } = self;
        rlwe_1.check_is_valid(params.l1_params.dim)?;
        Ok(())
    }
}
impl PublicKey {
    /// Generates an RLWE public key from `sk.glwe_1` using `params.l1_params`.
    ///
    /// # Panics
    ///
    /// Panics if `params.l1_params.dim.size.0 != 1`: public-key generation
    /// currently supports only a GLWE size of 1.
    pub fn generate(params: &Params, sk: &SecretKey) -> Self {
        assert_eq!(
            params.l1_params.dim.size.0, 1,
            "Unfortunately, public keys currently require a GLWE size of 1. This restriction will likely be eased in the future."
        );

        // Allocate the key for the level-1 parameters, then fill it in place.
        let mut rlwe_1 = RlwePublicKey::new(&params.l1_params);
        rlwe_generate_public_key(&mut rlwe_1, &sk.glwe_1, &params.l1_params);

        Self { rlwe_1 }
    }

    /// Generates a public key using the [`DEFAULT_128`] parameter set.
    ///
    /// NOTE(review): the other `*_with_default_params` constructors in this
    /// module use `Params::default()`. These are presumably the same parameters;
    /// confirm.
    ///
    /// # Panics
    ///
    /// Same conditions as [`PublicKey::generate`].
    pub fn generate_with_default_params(sk: &SecretKey) -> Self {
        Self::generate(&DEFAULT_128, sk)
    }
}
/// Secret key material: a level-0 LWE key and a level-1 GLWE key.
///
/// Field order is part of the serde-derived encoding; do not reorder the
/// fields without a format migration.
#[derive(Clone, Serialize, Deserialize)]
pub struct SecretKey {
    /// Level-0 LWE secret key, generated and validated against `Params::l0_params`.
    pub lwe_0: LweSecretKey<u64>,
    /// Level-1 GLWE secret key, generated and validated against `Params::l1_params`.
    pub glwe_1: GlweSecretKey<u64>,
}
impl GetSize for SecretKey {
    /// Size in bytes that `crate::safe_bincode` associates with a [`SecretKey`]
    /// for `params`.
    ///
    /// This is the `u64` word count of both component keys plus three extra
    /// `u64` words.
    fn get_size(params: &Params) -> usize {
        let lwe_words = LweSecretKeyRef::<u64>::size(params.l0_params.dim);
        let glwe_words = GlweSecretKeyRef::<u64>::size(params.l1_params.dim);
        // NOTE(review): where these three extra words come from (length
        // prefixes / headers) is not visible here — confirm against the encoding.
        let overhead_words = 3;
        (lwe_words + glwe_words + overhead_words) * size_of::<u64>()
    }

    /// Validates `lwe_0` against `params.l0_params.dim`, then `glwe_1` against
    /// `params.l1_params.dim`. Returns the first failure.
    fn check_is_valid(&self, params: &Params) -> crate::Result<()> {
        let Self { lwe_0, glwe_1 } = self;
        lwe_0.check_is_valid(params.l0_params.dim)?;
        glwe_1.check_is_valid(params.l1_params.dim)?;
        Ok(())
    }
}
impl SecretKey {
    /// Generates a fresh secret key for `params`.
    ///
    /// - `lwe_0` comes from `keygen::generate_binary_lwe_sk(&params.l0_params)`.
    /// - `glwe_1` comes from `keygen::generate_binary_glwe_sk(&params.l1_params)`.
    ///
    /// Use [`SecretKey::generate_with_seed`] when the key must be reproducible.
    pub fn generate(params: &Params) -> Self {
        // Struct-expression fields are evaluated in source order, so `lwe_0`
        // is still generated before `glwe_1`.
        Self {
            lwe_0: keygen::generate_binary_lwe_sk(&params.l0_params),
            glwe_1: keygen::generate_binary_glwe_sk(&params.l1_params),
        }
    }

    /// Generates a fresh secret key using `Params::default()`.
    pub fn generate_with_default_params() -> Self {
        Self::generate(&Params::default())
    }

    /// Derives both secret-key components from `seed`.
    ///
    /// The same `params` and `seed` always yield the same key.
    pub fn generate_with_seed(params: &Params, seed: &Seed) -> Self {
        let (lwe_0, glwe_1) = keygen::generate_binary_lwe_glwe_sk_with_seed(
            &params.l0_params,
            &params.l1_params,
            seed,
        );
        Self { lwe_0, glwe_1 }
    }

    /// Derives a secret key from `seed` using `Params::default()`.
    pub fn generate_with_default_params_and_seed(seed: &Seed) -> Self {
        Self::generate_with_seed(&Params::default(), seed)
    }
}
/// Compute keys in their non-FFT form, as produced by `ComputeKeyNonFft::generate`.
///
/// Convert to the FFT form used for evaluation with `ComputeKeyNonFft::fft`.
/// Field order is part of the serde-derived encoding; do not reorder the fields
/// without a format migration.
#[derive(Clone, Serialize, Deserialize)]
pub struct ComputeKeyNonFft {
    /// Bootstrapping key generated from `SecretKey::lwe_0` and `SecretKey::glwe_1`.
    pub bs_key: BootstrapKey<u64>,
    /// LWE keyswitch key generated from the LWE view of `SecretKey::glwe_1`
    /// and `SecretKey::lwe_0`.
    pub ks_key: LweKeyswitchKey<u64>,
    /// Automorphism key generated from `SecretKey::glwe_1` with `Params::tr_radix`.
    pub auto_key: AutomorphismKey<u64>,
    /// Scheme-switch key generated from `SecretKey::glwe_1` with `Params::ss_radix`.
    pub ss_key: SchemeSwitchKey<u64>,
}
impl GetSize for ComputeKeyNonFft {
    /// Size in bytes that `crate::safe_bincode` associates with a
    /// [`ComputeKeyNonFft`] for `params`.
    ///
    /// This is the `u64` element counts of the four keys plus four extra `u64`
    /// words (presumably one length prefix per key — TODO confirm).
    fn get_size(params: &Params) -> usize {
        let l0_dim = params.l0_params.dim;
        let l1_dim = params.l1_params.dim;

        let bs_words = BootstrapKeyRef::<u64>::size((
            l0_dim,
            l1_dim,
            params.pbs_radix.count,
            params.addend_count,
        ));
        let ks_words = LweKeyswitchKeyRef::<u64>::size((
            params.l1_params.as_lwe_def().dim,
            l0_dim,
            params.ks_radix.count,
        ));
        let ss_words = SchemeSwitchKeyRef::<u64>::size((l1_dim, params.ss_radix.count));
        let auto_words = AutomorphismKeyRef::<u64>::size((l1_dim, params.tr_radix.count));
        let overhead_words = 4;

        (bs_words + ks_words + ss_words + auto_words + overhead_words) * size_of::<u64>()
    }

    /// Validates each key against the shape implied by `params`.
    ///
    /// The keys are checked in this order: bootstrap, keyswitch, scheme-switch,
    /// automorphism. The first failure is returned.
    fn check_is_valid(&self, params: &Params) -> crate::Result<()> {
        let Self {
            bs_key,
            ks_key,
            auto_key,
            ss_key,
        } = self;
        let l0_dim = params.l0_params.dim;
        let l1_dim = params.l1_params.dim;

        bs_key.check_is_valid((
            l0_dim,
            l1_dim,
            params.pbs_radix.count,
            params.addend_count,
        ))?;
        ks_key.check_is_valid((
            params.l1_params.as_lwe_def().dim,
            l0_dim,
            params.ks_radix.count,
        ))?;
        ss_key.check_is_valid((l1_dim, params.ss_radix.count))?;
        auto_key.check_is_valid((l1_dim, params.tr_radix.count))?;
        Ok(())
    }
}
impl ComputeKeyNonFft {
pub fn generate(secret_key: &SecretKey, params: &Params) -> Self {
let bs_key = keygen::generate_bootstrapping_key(
&secret_key.lwe_0,
&secret_key.glwe_1,
¶ms.l0_params,
¶ms.l1_params,
¶ms.pbs_radix,
params.addend_count,
);
let ks_key = keygen::generate_ksk(
secret_key.glwe_1.to_lwe_secret_key(),
&secret_key.lwe_0,
¶ms.l1_params.as_lwe_def(),
¶ms.l0_params,
¶ms.ks_radix,
);
let mut ss_key = SchemeSwitchKey::new(¶ms.l1_params, ¶ms.ss_radix);
generate_scheme_switch_key(
&mut ss_key,
&secret_key.glwe_1,
¶ms.l1_params,
¶ms.ss_radix,
);
let mut auto_key = AutomorphismKey::new(¶ms.l1_params, ¶ms.tr_radix);
generate_automorphism_key(
&mut auto_key,
&secret_key.glwe_1,
¶ms.l1_params,
¶ms.tr_radix,
);
Self {
ks_key,
bs_key,
ss_key,
auto_key,
}
}
pub fn fft(&self, params: &Params) -> ComputeKey {
let mut ssk_fft = SchemeSwitchKeyFft::new(¶ms.l1_params, ¶ms.ss_radix);
self.ss_key
.fft(&mut ssk_fft, ¶ms.l1_params, ¶ms.ss_radix);
let mut auto_key_fft = AutomorphismKeyFft::new(¶ms.l1_params, ¶ms.tr_radix);
self.auto_key
.fft(&mut auto_key_fft, ¶ms.l1_params, ¶ms.tr_radix);
ComputeKey {
bs_key: fft::fft_bootstrap_key(
&self.bs_key,
¶ms.l0_params,
¶ms.l1_params,
¶ms.pbs_radix,
params.addend_count,
),
ks_key: self.ks_key.clone(),
ss_key: ssk_fft,
auto_key: auto_key_fft,
}
}
pub fn generate_with_default_params(secret_key: &SecretKey) -> Self {
let params = Params::default();
Self::generate(secret_key, ¶ms)
}
}
/// Compute keys in FFT form, as produced by `ComputeKeyNonFft::fft`.
///
/// Field order is part of the serde-derived encoding; do not reorder the fields
/// without a format migration.
#[derive(Clone, Serialize, Deserialize)]
pub struct ComputeKey {
    /// Bootstrapping key after `fft::fft_bootstrap_key`.
    pub bs_key: BootstrapKeyFft<Complex<f64>>,
    /// LWE keyswitch key, cloned unchanged from the non-FFT form.
    pub ks_key: LweKeyswitchKey<u64>,
    /// Scheme-switch key after `SchemeSwitchKey::fft`.
    pub ss_key: SchemeSwitchKeyFft<Complex<f64>>,
    /// Automorphism key after `AutomorphismKey::fft`.
    pub auto_key: AutomorphismKeyFft<Complex<f64>>,
}
impl GetSize for ComputeKey {
    /// Size in bytes that `crate::safe_bincode` associates with a [`ComputeKey`]
    /// for `params`.
    ///
    /// The element counts of the four keys are summed and scaled by
    /// `size_of::<Complex<f64>>()`. Four extra `u64` words are then added.
    ///
    /// NOTE(review): this is not an exact size for the stored types.
    /// - Only `auto_key` uses an FFT-form helper (`AutomorphismKeyFftRef`).
    /// - `bs_key` and `ss_key` are counted with the non-FFT
    ///   `BootstrapKeyRef`/`SchemeSwitchKeyRef` `u64` helpers.
    /// - The `u64` keyswitch key's count is scaled by the 16-byte complex size.
    ///
    /// Whether the total over- or under-approximates the real encoding depends
    /// on FFT layouts not visible here. Confirm before tightening it.
    fn get_size(params: &Params) -> usize {
        let l0_dim = params.l0_params.dim;
        let l1_dim = params.l1_params.dim;

        let bs_elems = BootstrapKeyRef::<u64>::size((
            l0_dim,
            l1_dim,
            params.pbs_radix.count,
            params.addend_count,
        ));
        let ks_elems = LweKeyswitchKeyRef::<u64>::size((
            params.l1_params.as_lwe_def().dim,
            l0_dim,
            params.ks_radix.count,
        ));
        let ss_elems = SchemeSwitchKeyRef::<u64>::size((l1_dim, params.ss_radix.count));
        let auto_elems =
            AutomorphismKeyFftRef::<Complex<f64>>::size((l1_dim, params.tr_radix.count));

        let payload_bytes =
            (bs_elems + ks_elems + ss_elems + auto_elems) * size_of::<Complex<f64>>();
        let overhead_bytes = 4 * size_of::<u64>();
        payload_bytes + overhead_bytes
    }

    /// Validates each key against the shape implied by `params`.
    ///
    /// The keys are checked in this order: bootstrap, keyswitch, scheme-switch,
    /// automorphism. The first failure is returned.
    fn check_is_valid(&self, params: &Params) -> crate::Result<()> {
        let Self {
            bs_key,
            ks_key,
            ss_key,
            auto_key,
        } = self;
        let l0_dim = params.l0_params.dim;
        let l1_dim = params.l1_params.dim;

        bs_key.check_is_valid((
            l0_dim,
            l1_dim,
            params.pbs_radix.count,
            params.addend_count,
        ))?;
        ks_key.check_is_valid((
            params.l1_params.as_lwe_def().dim,
            l0_dim,
            params.ks_radix.count,
        ))?;
        ss_key.check_is_valid((l1_dim, params.ss_radix.count))?;
        auto_key.check_is_valid((l1_dim, params.tr_radix.count))?;
        Ok(())
    }
}
impl ComputeKey {
    /// Generates FFT-form compute keys for `secret_key` under `params`.
    ///
    /// This first builds a [`ComputeKeyNonFft`] and then converts it with
    /// [`ComputeKeyNonFft::fft`]. The intermediate non-FFT keys are dropped.
    pub fn generate(secret_key: &SecretKey, params: &Params) -> Self {
        ComputeKeyNonFft::generate(secret_key, params).fft(params)
    }

    /// Generates FFT-form compute keys for `secret_key` using `Params::default()`.
    pub fn generate_with_default_params(secret_key: &SecretKey) -> Self {
        Self::generate(secret_key, &Params::default())
    }
}
#[cfg(test)]
mod tests {
    use crate::{DEFAULT_128, SecretKey, Seed};

    /// Reports whether two secret keys hold identical key material.
    ///
    /// The LWE components are compared through `s()`. The GLWE components are
    /// compared by their bincode encodings. Both comparisons are always run.
    fn secret_keys_equal(sk1: &SecretKey, sk2: &SecretKey) -> bool {
        let encode_glwe = |sk: &SecretKey| {
            bincode::serialize(&sk.glwe_1).expect("Failed to serialize GLWE key")
        };

        let lwe_match = sk1.lwe_0.s() == sk2.lwe_0.s();
        let glwe_match = encode_glwe(sk1) == encode_glwe(sk2);
        lwe_match && glwe_match
    }

    #[test]
    fn test_secret_key_seeded_deterministic() {
        let seed = Seed::from_bytes([42u8; 32]);
        let first = SecretKey::generate_with_seed(&DEFAULT_128, &seed);
        let second = SecretKey::generate_with_seed(&DEFAULT_128, &seed);
        assert!(
            secret_keys_equal(&first, &second),
            "Keys generated from same seed should be identical"
        );
    }

    #[test]
    fn test_secret_key_different_seeds() {
        let key_a = SecretKey::generate_with_seed(&DEFAULT_128, &Seed::from_bytes([1u8; 32]));
        let key_b = SecretKey::generate_with_seed(&DEFAULT_128, &Seed::from_bytes([2u8; 32]));
        assert!(!secret_keys_equal(&key_a, &key_b));
    }

    #[test]
    fn test_secret_key_with_default_params_and_seed() {
        let seed = Seed::from_bytes([123u8; 32]);
        let via_default = SecretKey::generate_with_default_params_and_seed(&seed);
        let via_explicit = SecretKey::generate_with_seed(&DEFAULT_128, &seed);
        assert!(secret_keys_equal(&via_default, &via_explicit));
    }
}