use crate::crypto::UnsignedTorus;
use crate::math::decomposition::{
DecompositionBaseLog, DecompositionLevel, DecompositionLevelCount,
};
use crate::math::fft::{Complex64, Fft, FourierPolynomial};
use crate::math::polynomial::{MonomialDegree, Polynomial, PolynomialList};
use crate::math::tensor::{AsMutSlice, AsMutTensor, AsRefSlice, AsRefTensor};
use crate::numeric::{CastInto, Numeric};
use crate::{ck_dim_eq, zip, zip_args};
use super::bootstrap::BootstrapKey;
use super::ggsw::GgswCiphertext;
use super::glwe::GlweCiphertext;
use super::lwe::{LweBody, LweCiphertext};
#[cfg(test)]
mod tests;
/// Computes the external product of `ggsw` with `glwe` and adds the result into `output`.
///
/// The computation is carried out in the Fourier domain:
///
/// 1. every coefficient of `glwe` is rounded **in place** to the closest multiple representable
///    with the decomposition parameters of `ggsw` (so `glwe` is modified by this call);
/// 2. the GGSW is read level by level (one `(k + 1) x (k + 1)` matrix of polynomials per level,
///    iterated in reverse order). For each level, every GLWE polynomial is decomposed by one
///    level of the signed decomposition, with a per-coefficient carry kept across levels. The
///    resulting digits are sent to the Fourier domain and multiplied-accumulated with the
///    matching GGSW row into `res_fft`;
/// 3. `res_fft` is brought back to the torus with the `add_backward_*` FFT routines into
///    `output`. These presumably add to, rather than overwrite, the existing contents; `cmux`
///    depends on that accumulation.
///
/// Polynomials are processed two at a time so the paired FFT routines can be used. When the GLWE
/// mask size is even, the GLWE has an odd number of polynomials and the last one goes through
/// the single-polynomial path.
///
/// # Arguments
///
/// * `fft` - FFT engine for the polynomial size.
/// * `dec_i_fft`, `tmp_dec_i_fft` - scratch buffers receiving the Fourier transforms of the
///   decomposed polynomials.
/// * `res_fft` - Fourier-domain accumulator, one polynomial per GLWE polynomial. It is only
///   accumulated into, so the caller must zero it beforehand (as `blind_rotate` does).
/// * `output` - GLWE ciphertext receiving the result.
/// * `ggsw` - GGSW ciphertext in the Fourier domain.
/// * `glwe` - GLWE ciphertext operand; rounded in place.
///
/// # Panics
///
/// The following dimension mismatches are checked with `ck_dim_eq!`:
///
/// * polynomial sizes of `glwe`, `output` and `ggsw`;
/// * GLWE sizes of `glwe`, `output` and `ggsw`;
/// * the length of `res_fft` against the GLWE size.
pub fn external_product<RgswCont, RlweCont, InCont, FftCont1, FftCont2, FftCont3, Scalar>(
    fft: &mut Fft,
    dec_i_fft: &mut FourierPolynomial<FftCont1>,
    tmp_dec_i_fft: &mut FourierPolynomial<FftCont2>,
    res_fft: &mut [FourierPolynomial<FftCont3>],
    output: &mut GlweCiphertext<InCont>,
    ggsw: &GgswCiphertext<RgswCont>,
    glwe: &mut GlweCiphertext<RlweCont>,
) where
    GlweCiphertext<InCont>: AsMutTensor<Element = Scalar>,
    GgswCiphertext<RgswCont>: AsRefTensor<Element = Complex64>,
    GlweCiphertext<RlweCont>: AsMutTensor<Element = Scalar>,
    FourierPolynomial<FftCont1>: AsMutTensor<Element = Complex64>,
    FourierPolynomial<FftCont2>: AsMutTensor<Element = Complex64>,
    FourierPolynomial<FftCont3>: AsMutTensor<Element = Complex64>,
    Scalar: UnsignedTorus,
{
    ck_dim_eq!(glwe.polynomial_size().0 => ggsw.polynomial_size().0);
    ck_dim_eq!(output.polynomial_size().0 => ggsw.polynomial_size().0);
    ck_dim_eq!(glwe.size().0 => ggsw.glwe_size().0);
    ck_dim_eq!(output.size().0 => ggsw.glwe_size().0);
    // `res_fft` is zipped with the GGSW rows and the output polynomials below. A shorter buffer
    // would silently drop polynomials from the result, and an empty one would panic on the
    // remainder indexing, so check its length up front.
    ck_dim_eq!(res_fft.len() => ggsw.glwe_size().0);
    let base_log = ggsw.decomposition_base_log().0;
    let level = ggsw.decomposition_level_count().0;
    let polynomial_size = glwe.polynomial_size().0;
    let dimension = glwe.mask_size().0;
    // The GLWE holds `dimension + 1` polynomials, processed in pairs. When `dimension` is even
    // that count is odd, and the last polynomial (the chunk remainder) is handled alone.
    let even_dimension = dimension % 2 == 0;
    let zero = <Scalar as Numeric>::ZERO;
    // Per-coefficient carries of the signed decomposition, kept across levels.
    let mut carry = vec![zero; polynomial_size * (dimension + 1)];
    let mut sign_decomp_0 = vec![zero; polynomial_size];
    let mut sign_decomp_1 = vec![zero; polynomial_size];
    // Round the operand so it is exactly representable by the decomposition. Note that this
    // mutates `glwe`.
    for value in glwe.as_mut_tensor().as_mut_slice().iter_mut() {
        *value = value.round_to_closest_multiple(
            DecompositionBaseLog(base_log),
            DecompositionLevelCount(level),
        );
    }
    // Coefficients in one GGSW level: a (dimension + 1) x (dimension + 1) polynomial matrix.
    let matrix_size = (dimension + 1) * (dimension + 1) * polynomial_size;
    for (j, rgsw_level) in ggsw
        .as_tensor()
        .as_slice()
        .chunks(matrix_size)
        .rev()
        .enumerate()
    {
        let dec_level = level - j - 1;
        let trlwe_chunks = glwe
            .as_tensor()
            .as_slice()
            .chunks_exact(2 * polynomial_size);
        // A GGSW row spans `(dimension + 1) * polynomial_size` coefficients; take rows in pairs.
        let rgsw_level_chunks = rgsw_level.chunks_exact(2 * (polynomial_size * (dimension + 1)));
        if even_dimension {
            // Unpaired last GLWE polynomial: decompose, transform and accumulate it alone.
            let rlwe_polynomial = trlwe_chunks.remainder();
            let carry_polynomial = carry
                .chunks_exact_mut(2 * polynomial_size)
                .into_remainder();
            let trgsw_line = rgsw_level_chunks.remainder();
            signed_decompose_one_level(
                &mut sign_decomp_0,
                carry_polynomial,
                rlwe_polynomial,
                DecompositionBaseLog(base_log),
                DecompositionLevel(dec_level),
            );
            fft.forward_as_integer(
                dec_i_fft,
                &Polynomial::from_container(sign_decomp_0.as_slice()),
            );
            for (trgsw_elt, res_fft_polynomial) in
                trgsw_line.chunks(polynomial_size).zip(res_fft.iter_mut())
            {
                res_fft_polynomial.update_with_multiply_accumulate(
                    &FourierPolynomial::from_container(trgsw_elt),
                    dec_i_fft,
                );
            }
        }
        // Paired GLWE polynomials: two decompositions share one paired forward FFT.
        for zip_args!(
            double_rlwe_polynomial,
            double_carry_polynomial,
            double_trgsw_line
        ) in zip!(
            trlwe_chunks,
            carry.chunks_exact_mut(2 * polynomial_size),
            rgsw_level_chunks
        ) {
            let (rlwe_polynomial_0, rlwe_polynomial_1) =
                double_rlwe_polynomial.split_at(polynomial_size);
            let (carry_polynomial_0, carry_polynomial_1) =
                double_carry_polynomial.split_at_mut(polynomial_size);
            let (trgsw_line_0, trgsw_line_1) =
                double_trgsw_line.split_at(polynomial_size * (dimension + 1));
            signed_decompose_one_level(
                &mut sign_decomp_0,
                carry_polynomial_0,
                rlwe_polynomial_0,
                DecompositionBaseLog(base_log),
                DecompositionLevel(dec_level),
            );
            signed_decompose_one_level(
                &mut sign_decomp_1,
                carry_polynomial_1,
                rlwe_polynomial_1,
                DecompositionBaseLog(base_log),
                DecompositionLevel(dec_level),
            );
            fft.forward_two_as_integer(
                dec_i_fft,
                tmp_dec_i_fft,
                &Polynomial::from_container(sign_decomp_0.as_slice()),
                &Polynomial::from_container(sign_decomp_1.as_slice()),
            );
            for zip_args!(trgsw_elt_0, trgsw_elt_1, res_fft_polynomial) in zip!(
                trgsw_line_0.chunks(polynomial_size),
                trgsw_line_1.chunks(polynomial_size),
                res_fft.iter_mut()
            ) {
                res_fft_polynomial.update_with_two_multiply_accumulate(
                    &FourierPolynomial::from_container(trgsw_elt_0),
                    dec_i_fft,
                    &FourierPolynomial::from_container(trgsw_elt_1),
                    tmp_dec_i_fft,
                );
            }
        }
    }
    // Back to the torus: the unpaired last polynomial first (if any), then the pairs.
    if even_dimension {
        let res_remainder = output
            .as_mut_tensor()
            .as_mut_slice()
            .chunks_exact_mut(2 * polynomial_size)
            .into_remainder();
        let res_fft_remainder = res_fft.chunks_exact_mut(2).into_remainder();
        fft.add_backward_as_torus(
            &mut Polynomial::from_container(res_remainder),
            &mut res_fft_remainder[0],
        );
    }
    for (double_res_polynomial, double_res_fft_polynomial) in output
        .as_mut_tensor()
        .as_mut_slice()
        .chunks_exact_mut(2 * polynomial_size)
        .zip(res_fft.chunks_exact_mut(2))
    {
        let (res_fft_0, res_fft_1) = double_res_fft_polynomial.split_at_mut(1);
        let (res_0, res_1) = double_res_polynomial.split_at_mut(polynomial_size);
        let mut res_0 = Polynomial::from_container(res_0);
        let mut res_1 = Polynomial::from_container(res_1);
        fft.add_backward_two_as_torus(&mut res_0, &mut res_1, &mut res_fft_0[0], &mut res_fft_1[0]);
    }
}
/// Homomorphic multiplexer (CMux) between `glwe_0` and `glwe_1`, controlled by `ggsw`.
///
/// `glwe_1` is first overwritten with `glwe_1 - glwe_0` (wrapping subtraction). Then
/// `external_product` is called with `glwe_0` as its output and that difference as its operand.
/// Assuming `external_product` adds into its output, `glwe_0` ends up holding
/// `glwe_0 + ggsw (x) (glwe_1 - glwe_0)`.
///
/// Both ciphertexts are mutated: `glwe_0` receives the result and `glwe_1` is left holding the
/// (rounded) difference. `res_fft` serves as the Fourier accumulator of `external_product` and
/// should be zeroed by the caller.
pub fn cmux<RlweCont0, RlweCont1, RgswCont, FftCont1, FftCont2, FftCont3, Scalar>(
    fft: &mut Fft,
    dec_i_fft: &mut FourierPolynomial<FftCont1>,
    tmp_dec_i_fft: &mut FourierPolynomial<FftCont2>,
    res_fft: &mut [FourierPolynomial<FftCont3>],
    glwe_0: &mut GlweCiphertext<RlweCont0>,
    glwe_1: &mut GlweCiphertext<RlweCont1>,
    ggsw: &GgswCiphertext<RgswCont>,
) where
    GgswCiphertext<RgswCont>: AsRefTensor<Element = Complex64>,
    FourierPolynomial<FftCont1>: AsMutTensor<Element = Complex64>,
    FourierPolynomial<FftCont2>: AsMutTensor<Element = Complex64>,
    FourierPolynomial<FftCont3>: AsMutTensor<Element = Complex64>,
    GlweCiphertext<RlweCont0>: AsMutTensor<Element = Scalar>,
    GlweCiphertext<RlweCont1>: AsMutTensor<Element = Scalar>,
    Scalar: UnsignedTorus,
{
    // Turn `glwe_1` into the selector operand `glwe_1 - glwe_0`.
    let difference = glwe_1.as_mut_tensor();
    difference.update_with_wrapping_sub(glwe_0.as_tensor());

    // Fold `ggsw (x) difference` into `glwe_0`.
    external_product(
        fft,
        dec_i_fft,
        tmp_dec_i_fft,
        res_fft,
        glwe_0,
        ggsw,
        glwe_1,
    );
}
/// Blindly rotates the GLWE accumulator `output` by the phase encrypted in `lwe`.
///
/// Every LWE coefficient is first switched from the native torus modulus to `2 * N`, where `N`
/// is the polynomial size (see `modulus_switch`). The accumulator is then divided by
/// `X^{b_hat}`, with `b_hat` the switched body.
///
/// Next, for every mask element `a_i` whose switched value `a_hat` is non-zero, a CMux
/// controlled by the `i`-th GGSW ciphertext of `bootstrap_key` selects between two values: the
/// current accumulator, and the accumulator multiplied by `X^{a_hat}`.
///
/// # Arguments
///
/// * `fft`, `dec_i_fft`, `tmp_dec_i_fft`, `res_fft` - FFT engine and scratch buffers forwarded
///   to `cmux`. `res_fft` is zeroed before every CMux and should hold one polynomial per GLWE
///   polynomial.
/// * `output` - accumulator GLWE ciphertext, rotated in place.
/// * `lwe` - input LWE ciphertext whose phase drives the rotation.
/// * `bootstrap_key` - Fourier-domain bootstrap key. It is read as consecutive GGSW ciphertexts
///   of `(k + 1)^2 * level * N` coefficients each, one per LWE mask element (`k` being the GLWE
///   mask size).
pub fn blind_rotate<OutCont, LweCont, BskCont, FftCont1, FftCont2, FftCont3, Scalar>(
    fft: &mut Fft,
    dec_i_fft: &mut FourierPolynomial<FftCont1>,
    tmp_dec_i_fft: &mut FourierPolynomial<FftCont2>,
    res_fft: &mut [FourierPolynomial<FftCont3>],
    output: &mut GlweCiphertext<OutCont>,
    lwe: &LweCiphertext<LweCont>,
    bootstrap_key: &BootstrapKey<BskCont>,
) where
    GlweCiphertext<OutCont>: AsMutTensor<Element = Scalar>,
    GlweCiphertext<Vec<Scalar>>: AsMutTensor<Element = Scalar>,
    LweCiphertext<LweCont>: AsRefTensor<Element = Scalar>,
    BootstrapKey<BskCont>: AsRefTensor<Element = Complex64>,
    FourierPolynomial<FftCont1>: AsMutTensor<Element = Complex64>,
    FourierPolynomial<FftCont2>: AsMutTensor<Element = Complex64>,
    FourierPolynomial<FftCont3>: AsMutTensor<Element = Complex64>,
    Scalar: UnsignedTorus,
{
    let dimension = output.mask_size().0;
    let level = bootstrap_key.level_count().0;
    let polynomial_size = output.polynomial_size().0;
    let (body_lwe, mask_lwe) = lwe.get_body_and_mask();

    // Initial rotation of the accumulator by X^{-b_hat}.
    let b_hat = modulus_switch(body_lwe.0, polynomial_size);
    output
        .as_mut_polynomial_list()
        .update_with_wrapping_monic_monomial_div(MonomialDegree(b_hat));

    // Scratch ciphertext holding `output * X^{a_hat}` for the CMux.
    let mut ct_1 = GlweCiphertext::allocate(Scalar::ZERO, output.polynomial_size(), output.size());
    // One GGSW of the key: `level` matrices of (k + 1) x (k + 1) polynomials.
    let trgsw_size: usize = (dimension + 1) * (dimension + 1) * level * polynomial_size;
    for (a, trgsw_i) in mask_lwe
        .mask_element_iter()
        .zip(bootstrap_key.as_tensor().as_slice().chunks(trgsw_size))
    {
        let a_hat = modulus_switch(*a, polynomial_size);
        // X^0 is the identity rotation: the CMux would select between two equal ciphertexts,
        // so skip it (and the copy into `ct_1`) entirely.
        if a_hat == 0 {
            continue;
        }
        ct_1.as_mut_tensor()
            .as_mut_slice()
            .copy_from_slice(output.as_tensor().as_slice());
        ct_1.as_mut_polynomial_list()
            .update_with_wrapping_monic_monomial_mul(MonomialDegree(a_hat));
        // `external_product` only accumulates into `res_fft`, so it must start from zero.
        for res_fft_polynomial in res_fft.iter_mut() {
            for m in res_fft_polynomial.coefficient_iter_mut() {
                *m = Complex64::new(0., 0.);
            }
        }
        cmux(
            fft,
            dec_i_fft,
            tmp_dec_i_fft,
            res_fft,
            output,
            &mut ct_1,
            &GgswCiphertext::from_container(
                trgsw_i,
                bootstrap_key.glwe_size(),
                bootstrap_key.polynomial_size(),
                bootstrap_key.base_log(),
            ),
        );
    }
}

/// Switches a torus element from the native modulus `2^w` to the modulus `2 * polynomial_size`.
///
/// Computes `round(value / 2^w * 2N) mod 2N` in `f64`, where `2^w` is obtained as
/// `MAX + 1` in floating point. The final reduction matters because values close to `2^w`
/// round up to exactly `2N`. In `Z[X]/(X^N + 1)` that is the same rotation as `0`
/// (`X^{2N} = 1`), and reducing lets callers treat both cases identically.
fn modulus_switch<Scalar>(value: Scalar, polynomial_size: usize) -> usize
where
    Scalar: UnsignedTorus,
{
    let n_coefs: f64 = polynomial_size.cast_into();
    let tmp: f64 = value.cast_into() / (<Scalar as Numeric>::MAX.cast_into() + 1.);
    let tmp: f64 = tmp * 2. * n_coefs;
    let rounded: usize = tmp.round().cast_into();
    rounded % (2 * polynomial_size)
}
/// Extracts the constant coefficient of the GLWE plaintext as an LWE ciphertext.
///
/// Each mask polynomial of `glwe` is written into the matching block of the LWE mask in two
/// steps. First its coefficients are reversed and negated. Then the block is multiplied by the
/// monic monomial `X` with `update_with_wrapping_monic_monomial_mul`.
///
/// The LWE body is set to the constant coefficient of the GLWE body.
pub fn constant_sample_extract<LweCont, RlweCont, Scalar>(
    lwe: &mut LweCiphertext<LweCont>,
    glwe: &GlweCiphertext<RlweCont>,
) where
    LweCiphertext<LweCont>: AsMutTensor<Element = Scalar>,
    GlweCiphertext<RlweCont>: AsRefTensor<Element = Scalar>,
    Scalar: UnsignedTorus,
{
    let chunk_len = glwe.polynomial_size().0;
    let (lwe_body, mut lwe_mask) = lwe.get_mut_body_and_mask();
    let (glwe_body, glwe_mask) = glwe.get_body_and_mask();

    // Step 1: every LWE mask block receives the negated, reversed GLWE mask polynomial.
    let target_blocks = lwe_mask.as_mut_tensor().as_mut_slice().chunks_mut(chunk_len);
    let source_blocks = glwe_mask.as_tensor().as_slice().chunks(chunk_len);
    target_blocks.zip(source_blocks).for_each(|(target, source)| {
        target
            .iter_mut()
            .zip(source.iter().rev())
            .for_each(|(dst, src)| *dst = (Scalar::ZERO).wrapping_sub(*src));
    });

    // Step 2: multiply every block by X.
    PolynomialList::from_container(
        lwe_mask.as_mut_tensor().as_mut_slice(),
        glwe.polynomial_size(),
    )
    .update_with_wrapping_monic_monomial_mul(MonomialDegree(1));

    // Step 3: the body is the constant coefficient of the GLWE body.
    *lwe_body = LweBody(*glwe_body.as_tensor().get_element(0));
}
/// Bootstraps `lwe_in` into `lwe_out` using `bootstrap_key` and the GLWE `accumulator`.
///
/// The call allocates, for its own use:
///
/// * a fresh `Fft` engine;
/// * two Fourier scratch polynomials;
/// * a Fourier accumulator with one polynomial per GLWE polynomial.
///
/// It then runs `blind_rotate` on `accumulator`, which is modified in place. Finally it extracts
/// the constant coefficient of the rotated accumulator into `lwe_out` with
/// `constant_sample_extract`.
pub fn bootstrap<OutCont, InCont, BskCont, AccCont, Scalar>(
    lwe_out: &mut LweCiphertext<OutCont>,
    lwe_in: &LweCiphertext<InCont>,
    bootstrap_key: &BootstrapKey<BskCont>,
    accumulator: &mut GlweCiphertext<AccCont>,
) where
    LweCiphertext<OutCont>: AsMutTensor<Element = Scalar>,
    LweCiphertext<InCont>: AsRefTensor<Element = Scalar>,
    BootstrapKey<BskCont>: AsMutTensor<Element = Complex64>,
    GlweCiphertext<AccCont>: AsMutTensor<Element = Scalar>,
    Scalar: UnsignedTorus,
{
    let poly_size = bootstrap_key.polynomial_size();
    let mask_size = bootstrap_key.glwe_size().0 - 1;

    // Scratch space shared by every external product of the blind rotation.
    let mut engine = Fft::new(poly_size);
    let mut digits_fourier = FourierPolynomial::allocate(Complex64::new(0., 0.), poly_size);
    let mut digits_fourier_pair = FourierPolynomial::allocate(Complex64::new(0., 0.), poly_size);
    let mut acc_fourier =
        vec![FourierPolynomial::allocate(Complex64::new(0., 0.), poly_size); mask_size + 1];

    blind_rotate(
        &mut engine,
        &mut digits_fourier,
        &mut digits_fourier_pair,
        &mut acc_fourier,
        accumulator,
        lwe_in,
        bootstrap_key,
    );
    constant_sample_extract(lwe_out, accumulator);
}
/// Decomposes every coefficient of `polynomial` by one level of the signed decomposition.
///
/// For each coefficient, the scalar's `signed_decompose_one_level` method is called with the
/// current carry. The first returned component (the digit for `dec_level`) is written to
/// `sign_decomp`. The second replaces the matching entry of `carries` in place, so callers can
/// feed it into the next level. Processing stops at the end of the shortest of the three slices.
fn signed_decompose_one_level<Scalar>(
    sign_decomp: &mut [Scalar],
    carries: &mut [Scalar],
    polynomial: &[Scalar],
    base_log: DecompositionBaseLog,
    dec_level: DecompositionLevel,
) where
    Scalar: UnsignedTorus,
{
    let lanes = sign_decomp
        .iter_mut()
        .zip(carries.iter_mut())
        .zip(polynomial.iter());
    for ((digit_slot, carry_slot), coefficient) in lanes {
        let decomposed =
            coefficient.signed_decompose_one_level(*carry_slot, base_log, dec_level);
        *digit_slot = decomposed.0;
        *carry_slot = decomposed.1;
    }
}