use std::{ops::Deref, sync::Arc};
use sunscreen_tfhe::{
entities::GlweCiphertextFft,
ops::{
bootstrapping::{
circuit_bootstrap_via_trace_and_scheme_switch, rotate_glwe_positive_monomial_negacyclic,
},
ciphertext::sample_extract,
fft_ops::{cmux, glev_cmux, glwe_ggsw_mad, scheme_switch_fft},
keyswitch::lwe_keyswitch::keyswitch_lwe_to_lwe,
},
};
use crate::params::Params;
use super::{
ComputeKey, L1GlevCiphertext, TrivialOne, TrivialZero,
encryption::{
Encryption, L0LweCiphertext, L1GgswCiphertext, L1GlweCiphertext, L1LweCiphertext,
},
};
/// Homomorphic operations on L1 ciphertexts that require no evaluation key.
///
/// The struct holds only the public [`Params`] and two trivial L1 GLWE
/// constants, so it can be built and cloned without a [`ComputeKey`].
#[derive(Clone)]
pub struct KeylessEvaluation {
    /// Parameter set used by every operation.
    pub params: Params,
    /// Trivial L1 GLWE encryption of zero.
    // No method in this file reads this field, and derived `Clone` does not count
    // as a read, so `dead_code` is allowed explicitly to keep the build warning-free.
    #[allow(dead_code)]
    l1glwe_zero: L1GlweCiphertext,
    /// Trivial L1 GLWE encryption of one; added to a ciphertext by `not`.
    l1glwe_one: L1GlweCiphertext,
}
impl KeylessEvaluation {
    /// Creates a key-free evaluator for `params`, taking its trivial L1 GLWE
    /// constants (zero and one) from `enc`.
    pub fn new(params: &Params, enc: &Encryption) -> Self {
        let zero = L1GlweCiphertext::trivial_zero(enc);
        let one = L1GlweCiphertext::trivial_one(enc);
        Self {
            params: params.clone(),
            l1glwe_zero: zero,
            l1glwe_one: one,
        }
    }

    /// Homomorphic NOT: stores `input` plus the trivial encryption of one in `output`.
    ///
    /// The tests check that this flips the bit held in the constant coefficient.
    pub fn not(&self, output: &mut L1GlweCiphertext, input: &L1GlweCiphertext) {
        output.0 = input.0.as_ref() + self.l1glwe_one.0.as_ref();
    }

    /// Homomorphic XOR: stores the ciphertext sum `a + b` in `output`.
    pub fn xor(&self, output: &mut L1GlweCiphertext, a: &L1GlweCiphertext, b: &L1GlweCiphertext) {
        output.0 = a.0.as_ref() + b.0.as_ref();
    }

    /// Multiplies the encrypted polynomial by the monomial `X^n` (negacyclic
    /// rotation) and writes the result to `output`.
    pub fn mul_xn(&self, output: &mut L1GlweCiphertext, input: &L1GlweCiphertext, n: usize) {
        let glwe_params = &self.params.l1_params;
        rotate_glwe_positive_monomial_negacyclic(&mut output.0, &input.0, n, glwe_params);
    }

    /// Homomorphic multiplexer: `output` receives `a` when `sel` encrypts
    /// `false` and `b` when it encrypts `true`.
    pub fn cmux(
        &self,
        output: &mut L1GlweCiphertext,
        sel: &L1GgswCiphertext,
        a: &L1GlweCiphertext,
        b: &L1GlweCiphertext,
    ) {
        let Self { params, .. } = self;
        cmux(
            &mut output.0,
            &a.0,
            &b.0,
            &sel.0,
            &params.l1_params,
            &params.cbs_radix,
        );
    }

    /// GLev counterpart of [`KeylessEvaluation::cmux`]: selects between the
    /// GLev ciphertexts `a` and `b` according to the GGSW selector `sel`.
    pub fn glev_cmux(
        &self,
        output: &mut L1GlevCiphertext,
        sel: &L1GgswCiphertext,
        a: &L1GlevCiphertext,
        b: &L1GlevCiphertext,
    ) {
        let Self { params, .. } = self;
        glev_cmux(
            &mut output.0,
            &a.0,
            &b.0,
            &sel.0,
            &params.l1_params,
            &params.cbs_radix,
        );
    }

    /// External product of an L1 GLWE ciphertext with an L1 GGSW ciphertext.
    ///
    /// `output` is cleared, the product is accumulated into a fresh FFT-domain
    /// buffer, and that buffer is transformed back into `output`.
    pub fn multiply_glwe_ggsw(
        &self,
        output: &mut L1GlweCiphertext,
        glwe: &L1GlweCiphertext,
        ggsw: &L1GgswCiphertext,
    ) {
        let l1 = &self.params.l1_params;
        output.0.clear();
        let mut acc = GlweCiphertextFft::new(l1);
        glwe_ggsw_mad(&mut acc, &glwe.0, &ggsw.0, l1, &self.params.cbs_radix);
        acc.ifft(&mut output.0, l1);
    }

    /// Extracts coefficient `idx` of the encrypted polynomial in `input` as an
    /// L1 LWE ciphertext written to `output`.
    pub fn sample_extract_l1(
        &self,
        output: &mut L1LweCiphertext,
        input: &L1GlweCiphertext,
        idx: usize,
    ) {
        let l1 = &self.params.l1_params;
        sample_extract(&mut output.0, &input.0, idx, l1);
    }
}
/// Full evaluator: the key-free operations of [`KeylessEvaluation`] (reachable
/// through `Deref`) plus operations that need the shared [`ComputeKey`].
#[derive(Clone)]
pub struct Evaluation {
/// Key-free operations and the parameter set.
keyless_eval: KeylessEvaluation,
/// Shared evaluation keys (bootstrapping, automorphism, scheme-switch and key-switch keys are read from it).
compute_key: Arc<ComputeKey>,
/// L1 GGSW of `false`, produced in `new` by circuit-bootstrapping a trivial L0 LWE of zero.
l1ggsw_zero: L1GgswCiphertext,
/// L1 GGSW of `true`, produced in `new` by circuit-bootstrapping a trivial L0 LWE of one.
l1ggsw_one: L1GgswCiphertext,
}
impl Deref for Evaluation {
    type Target = KeylessEvaluation;

    /// Exposes the key-free operations (and `params`) directly on an [`Evaluation`].
    fn deref(&self) -> &KeylessEvaluation {
        let Self { keyless_eval, .. } = self;
        keyless_eval
    }
}
impl Evaluation {
    /// Builds an evaluator around a shared `compute_key`.
    ///
    /// Besides the key-free operations of [`KeylessEvaluation`], this precomputes
    /// L1 GGSW ciphertexts of `false` and `true`. It does so by circuit-bootstrapping
    /// trivial L0 LWE encryptions of those bits, which costs two circuit bootstraps.
    /// The results are returned by [`Evaluation::l1ggsw_zero`] and
    /// [`Evaluation::l1ggsw_one`].
    pub fn new(compute_key: Arc<ComputeKey>, params: &Params, enc: &Encryption) -> Self {
        // Circuit-bootstraps a trivial L0 encryption of `msg` into a fresh L1 GGSW.
        // `self` does not exist yet, so the key and params are passed explicitly.
        let mk_ggsw = |msg: bool| {
            let lwe = if msg {
                enc.trivial_lwe_l0_one()
            } else {
                enc.trivial_lwe_l0_zero()
            };
            let mut output = enc.allocate_ggsw_l1();
            Self::circuit_bootstrap_with(&compute_key, params, &mut output, &lwe);
            output
        };
        let l1ggsw_zero = mk_ggsw(false);
        let l1ggsw_one = mk_ggsw(true);
        Self {
            keyless_eval: KeylessEvaluation::new(params, enc),
            compute_key,
            l1ggsw_zero,
            l1ggsw_one,
        }
    }

    /// Builds an evaluator from `Params::default()` and `Encryption::default()`.
    ///
    /// NOTE(review): `compute_key` presumably must have been generated for these
    /// default parameters — confirm with the key-generation code.
    pub fn with_default_params(compute_key: Arc<ComputeKey>) -> Self {
        let params = Params::default();
        let enc = Encryption::default();
        Self::new(compute_key, &params, &enc)
    }

    /// Single implementation of circuit bootstrapping, shared by `new` (before
    /// `self` exists) and [`Evaluation::circuit_bootstrap`].
    fn circuit_bootstrap_with(
        compute_key: &ComputeKey,
        params: &Params,
        output: &mut L1GgswCiphertext,
        input: &L0LweCiphertext,
    ) {
        circuit_bootstrap_via_trace_and_scheme_switch(
            &mut output.0,
            &input.0,
            &compute_key.bs_key,
            &compute_key.auto_key,
            &compute_key.ss_key,
            &params.l0_params,
            &params.l1_params,
            &params.pbs_radix,
            &params.tr_radix,
            &params.ss_radix,
            &params.cbs_radix,
            params.addend_count,
        );
    }

    /// Circuit-bootstraps the L0 LWE ciphertext `input` into the L1 GGSW
    /// ciphertext `output` (trace + scheme-switch variant).
    pub fn circuit_bootstrap(&self, output: &mut L1GgswCiphertext, input: &L0LweCiphertext) {
        Self::circuit_bootstrap_with(&self.compute_key, &self.params, output, input);
    }

    /// Converts the L1 GLev ciphertext `input` into an L1 GGSW ciphertext using
    /// the compute key's scheme-switching key.
    pub fn scheme_switch(&self, output: &mut L1GgswCiphertext, input: &L1GlevCiphertext) {
        scheme_switch_fft(
            &mut output.0,
            &input.0,
            &self.compute_key.ss_key,
            &self.params.l1_params,
            &self.params.cbs_radix,
            &self.params.ss_radix,
        );
    }

    /// Key-switches an LWE ciphertext under the L1 key (viewed as an LWE key via
    /// `as_lwe_def`) to one under the L0 key, using the compute key's `ks_key`.
    pub fn keyswitch_lwe_l1_lwe_l0(&self, output: &mut L0LweCiphertext, input: &L1LweCiphertext) {
        keyswitch_lwe_to_lwe(
            &mut output.0,
            &input.0,
            &self.compute_key.ks_key,
            &self.params.l1_params.as_lwe_def(),
            &self.params.l0_params,
            &self.params.ks_radix,
        );
    }

    /// The precomputed L1 GGSW encryption of `false`.
    pub fn l1ggsw_zero(&self) -> &L1GgswCiphertext {
        &self.l1ggsw_zero
    }

    /// The precomputed L1 GGSW encryption of `true`.
    pub fn l1ggsw_one(&self) -> &L1GgswCiphertext {
        &self.l1ggsw_one
    }
}
#[cfg(test)]
mod tests {
    use sunscreen_tfhe::entities::Polynomial;
    use crate::{
        DEFAULT_128,
        test_utils::{get_encryption_128, get_evaluation_128, get_secret_keys_128},
    };

    /// Builds an L1 plaintext polynomial whose constant coefficient is `bit`
    /// and whose remaining coefficients are zero.
    fn constant_poly(bit: bool) -> Polynomial<u64> {
        let mut coeffs = vec![0u64; DEFAULT_128.l1_poly_degree().0];
        coeffs[0] = u64::from(bit);
        Polynomial::new(&coeffs)
    }

    #[test]
    fn can_circuit_bootstrap() {
        let secret = get_secret_keys_128();
        let enc = get_encryption_128();
        let eval = get_evaluation_128();
        let mut ggsw = enc.allocate_ggsw_l1();
        for bit in [false, true] {
            let lwe = enc.encrypt_lwe_l0_secret(bit, &secret);
            eval.circuit_bootstrap(&mut ggsw, &lwe);
            assert_eq!(enc.decrypt_ggsw_l1(&ggsw, &secret), bit);
        }
    }

    #[test]
    fn can_lwe_keyswitch() {
        let secret = get_secret_keys_128();
        let enc = get_encryption_128();
        let eval = get_evaluation_128();
        let mut lwe_0 = enc.allocate_lwe_l0();
        for bit in [false, true] {
            let lwe_1 = enc.encrypt_lwe_l1_secret(bit, &secret);
            eval.keyswitch_lwe_l1_lwe_l0(&mut lwe_0, &lwe_1);
            assert_eq!(enc.decrypt_lwe_l0(&lwe_0, &secret), bit);
        }
    }

    #[test]
    fn can_cmux() {
        let secret = get_secret_keys_128();
        let enc = get_encryption_128();
        let eval = get_evaluation_128();
        let mut ggsw = enc.allocate_ggsw_l1();
        let mut result = enc.allocate_glwe_l1();
        let zero = enc.trivial_glwe_l1_zero();
        let one = enc.trivial_glwe_l1_one();
        for bit in [false, true] {
            let sel = enc.encrypt_lwe_l0_secret(bit, &secret);
            eval.circuit_bootstrap(&mut ggsw, &sel);
            eval.cmux(&mut result, &ggsw, &zero, &one);
            assert_eq!(
                enc.decrypt_glwe_l1(&result, &secret).coeffs()[0],
                u64::from(bit)
            );
        }
    }

    #[test]
    fn can_sample_extract() {
        let secret = get_secret_keys_128();
        let enc = get_encryption_128();
        let eval = get_evaluation_128();
        let mut coeffs = vec![0u64; DEFAULT_128.l1_poly_degree().0];
        coeffs[1] = 1;
        let ct = enc.encrypt_glwe_l1_secret(&Polynomial::new(&coeffs), &secret);
        let mut lwe = enc.allocate_lwe_l1();
        eval.sample_extract_l1(&mut lwe, &ct, 1);
        assert!(enc.decrypt_lwe_l1(&lwe, &secret));
        // The neighbouring zero coefficient must extract as `false`, proving the
        // index is honoured rather than every extraction decrypting to `true`.
        eval.sample_extract_l1(&mut lwe, &ct, 0);
        assert!(!enc.decrypt_lwe_l1(&lwe, &secret));
    }

    #[test]
    fn can_not() {
        let secret = get_secret_keys_128();
        let enc = get_encryption_128();
        let eval = get_evaluation_128();
        let mut output = enc.allocate_glwe_l1();
        for bit in [false, true] {
            let input = enc.encrypt_glwe_l1_secret(&constant_poly(bit), &secret);
            eval.not(&mut output, &input);
            assert_eq!(
                enc.decrypt_glwe_l1(&output, &secret).coeffs()[0],
                u64::from(!bit)
            );
        }
    }

    #[test]
    fn can_xor() {
        let secret = get_secret_keys_128();
        let enc = get_encryption_128();
        let eval = get_evaluation_128();
        let mut output = enc.allocate_glwe_l1();
        for a in [false, true] {
            for b in [false, true] {
                let a_enc = enc.encrypt_glwe_l1_secret(&constant_poly(a), &secret);
                let b_enc = enc.encrypt_glwe_l1_secret(&constant_poly(b), &secret);
                eval.xor(&mut output, &a_enc, &b_enc);
                assert_eq!(
                    enc.decrypt_glwe_l1(&output, &secret).coeffs()[0],
                    u64::from(a ^ b),
                    "xor({a}, {b})"
                );
            }
        }
    }

    #[test]
    fn can_multiply_glwe_ggsw() {
        let secret = get_secret_keys_128();
        let enc = get_encryption_128();
        let eval = get_evaluation_128();
        for a in [false, true] {
            for b in [false, true] {
                let a_enc = enc.encrypt_glwe_l1_secret(&constant_poly(a), &secret);
                let b_enc = enc.encrypt_ggsw_l1_secret(b, &secret);
                let mut output = enc.allocate_glwe_l1();
                eval.multiply_glwe_ggsw(&mut output, &a_enc, &b_enc);
                let actual = enc.decrypt_glwe_l1(&output, &secret);
                assert_eq!(actual.coeffs()[0], u64::from(a && b));
                for (i, &c) in actual.coeffs().iter().enumerate().skip(1) {
                    assert_eq!(c, 0, "coefficient {i} should be zero");
                }
            }
        }
    }

    #[test]
    fn can_mul_xn() {
        let sk = get_secret_keys_128();
        let enc = get_encryption_128();
        let eval = get_evaluation_128();
        let mut msg = Polynomial::<u64>::zero(DEFAULT_128.l1_poly_degree().0);
        msg.coeffs_mut()[0] = 1;
        msg.coeffs_mut()[2] = 1;
        let ct = enc.encrypt_glwe_l1_secret(&msg, &sk);
        let mut output = enc.allocate_glwe_l1();
        eval.mul_xn(&mut output, &ct, 5);
        let ans = enc.decrypt_glwe_l1(&output, &sk);
        // X^0 + X^2 multiplied by X^5 must give X^5 + X^7.
        for (i, &c) in ans.coeffs().iter().enumerate() {
            assert_eq!(c, u64::from(i == 5 || i == 7), "coefficient {i}");
        }
    }

    #[test]
    fn can_scheme_switch() {
        let sk = get_secret_keys_128();
        let enc = get_encryption_128();
        let eval = get_evaluation_128();
        for bit in [false, true] {
            let glev = enc.encrypt_glev_l1_secret(&constant_poly(bit), &sk);
            let mut ggsw = enc.allocate_ggsw_l1();
            eval.scheme_switch(&mut ggsw, &glev);
            assert_eq!(enc.decrypt_ggsw_l1(&ggsw, &sk), bit);
        }
    }
}