use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulusKind;
use crate::core_crypto::commons::plan::GenericPlanMap;
use crate::core_crypto::commons::utils::izip_eq;
use crate::core_crypto::prelude::*;
use std::sync::{Arc, OnceLock};
use tfhe_ntt::prime64::Plan;
/// Owned handle to a 64-bit negacyclic NTT plan over a prime modulus.
///
/// The plan is stored behind an [`Arc`], so cloning an `Ntt64` shares the same underlying
/// plan instead of rebuilding it.
#[derive(Debug, Clone)]
pub struct Ntt64 {
    /// Shared, immutable transform plan.
    plan: Arc<Plan>,
}
/// Cheap, `Copy` borrowed view over a 64-bit NTT [`Plan`].
///
/// All transform entry points are implemented on this view so they can take `self` by value.
#[derive(Debug, Clone, Copy)]
pub struct Ntt64View<'a> {
    /// Borrowed transform plan.
    pub(crate) plan: &'a Plan,
}
impl Ntt64 {
#[inline]
pub fn as_view(&self) -> Ntt64View<'_> {
Ntt64View { plan: &self.plan }
}
}
type PlanMap = GenericPlanMap<(PolynomialSize, CiphertextModulus<u64>), Plan>;
pub(crate) static PLANS: OnceLock<PlanMap> = OnceLock::new();
fn plans() -> &'static PlanMap {
PLANS.get_or_init(GenericPlanMap::new)
}
impl Ntt64 {
    /// Returns an NTT for `polynomial_size` coefficients modulo the custom prime `modulus`.
    ///
    /// The plan is obtained from the process-wide cache using the key
    /// `(polynomial_size, modulus)`. A fresh plan is only built through `Plan::try_new` when the
    /// cache's `get_or_init` decides one is needed (NOTE(review): presumably only on a cache
    /// miss; confirm against `GenericPlanMap`).
    ///
    /// # Panics
    ///
    /// - if `modulus.kind()` is not [`CiphertextModulusKind::Other`];
    /// - if `tfhe_ntt` refuses to build a plan for this `(size, modulus)` pair.
    pub fn new(modulus: CiphertextModulus<u64>, polynomial_size: PolynomialSize) -> Self {
        let cache = plans();
        assert_eq!(modulus.kind(), CiphertextModulusKind::Other);

        let plan = cache.get_or_init((polynomial_size, modulus), |(polynomial_size, modulus)| {
            let prime = modulus.get_custom_modulus() as u64;
            match Plan::try_new(polynomial_size.0, prime) {
                Some(built) => built,
                None => panic!(
                    "could not generate an NTT plan for the given (size, modulus) ({}, {modulus})",
                    polynomial_size.0
                ),
            }
        });

        Self { plan }
    }
}
impl Ntt64View<'_> {
    /// Number of coefficients the underlying plan transforms.
    pub fn polynomial_size(self) -> PolynomialSize {
        PolynomialSize(self.plan.ntt_size())
    }

    /// Prime modulus the underlying plan works with.
    pub fn custom_modulus(self) -> u64 {
        self.plan.modulus()
    }

    /// Copies `src` into `dst`, then runs the plan's forward transform in place on `dst`.
    fn load_and_forward(self, dst: &mut [u64], src: &[u64]) {
        dst.copy_from_slice(src);
        self.plan.fwd(dst);
    }

    /// Writes the forward NTT of `standard` into `ntt`.
    ///
    /// `standard` is left untouched. Both polynomials must have the same length, because
    /// `copy_from_slice` panics otherwise.
    pub fn forward(self, mut ntt: PolynomialMutView<'_, u64>, standard: PolynomialView<'_, u64>) {
        self.load_and_forward(ntt.as_mut(), standard.as_ref());
    }

    /// Same as [`Self::forward`], followed by the plan's `normalize` step on the output.
    ///
    /// NOTE(review): `normalize` presumably applies the `1/N mod p` scaling so that a later
    /// un-normalized inverse recovers the input; confirm against the `tfhe_ntt` docs.
    pub fn forward_normalized(
        self,
        mut ntt: PolynomialMutView<'_, u64>,
        standard: PolynomialView<'_, u64>,
    ) {
        let spectrum: &mut [u64] = ntt.as_mut();
        self.load_and_forward(spectrum, standard.as_ref());
        self.plan.normalize(spectrum);
    }

    /// Applies the inverse transform to `ntt` in place, then adds the result into `standard`
    /// coefficient-wise, modulo [`Self::custom_modulus`].
    ///
    /// `ntt` is clobbered: it holds the inverse transform afterwards.
    pub fn add_backward(
        self,
        mut standard: PolynomialMutView<'_, u64>,
        mut ntt: PolynomialMutView<'_, u64>,
    ) {
        let spectrum: &mut [u64] = ntt.as_mut();
        let acc: &mut [u64] = standard.as_mut();
        self.plan.inv(spectrum);

        let modulus = self.custom_modulus();
        pulp::Arch::new().dispatch(
            #[inline(always)]
            || {
                for (dst, src) in izip_eq!(acc, &*spectrum) {
                    *dst = u64::wrapping_add_custom_mod(*dst, *src, modulus);
                }
            },
        )
    }
}
impl Ntt64View<'_> {
    /// Tells whether data living under `from` must be modulus-switched before entering this NTT.
    ///
    /// Returns `None` when `from` is exactly the NTT prime. Otherwise it returns
    /// `Some(width)`, where `width` is the bit width of the power-of-two modulus `from`:
    /// `u64::BITS` for the native modulus, `log2(from)` for a custom power of two.
    ///
    /// # Panics
    ///
    /// Panics if `from` is neither the NTT prime nor compatible with the native modulus.
    pub(crate) fn modswitch_requirement(self, from: CiphertextModulus<u64>) -> Option<u32> {
        let ntt_modulus = CiphertextModulus::new(self.plan.modulus() as u128);
        if from == ntt_modulus {
            return None;
        }

        assert!(
            from.is_compatible_with_native_modulus(),
            "Only support implicit modswitch from pow-of-two modulus to ntt_modulus"
        );

        if from.is_native_modulus() {
            Some(u64::BITS)
        } else {
            Some(from.get_custom_modulus().ilog2())
        }
    }

    /// Switches every value of `data` from modulus `2^w` (`w = input_modulus_width`) to the NTT
    /// prime `p`, in place.
    ///
    /// Each word holds its value `x` in its top `w` bits. The function computes
    /// `round(x * p / 2^w)` and then reduces it into `[0, p)`. The reduction is needed because,
    /// when `p <= 2^(w - 1)` and `x` is close to `2^w`, the rounded value can be exactly `p`.
    ///
    /// `input_modulus_width` must lie in `1..=64`.
    pub(crate) fn modswitch_from_power_of_two_to_ntt_prime(
        self,
        input_modulus_width: u32,
        data: &mut [u64],
    ) {
        debug_assert!((1..=u64::BITS).contains(&input_modulus_width));
        let mod_p = self.plan.modulus();
        let mod_p_u128 = mod_p as u128;
        // Rounding offset for the division by 2^w.
        let half = 1u128 << (input_modulus_width - 1);

        for val in data.iter_mut() {
            let x = (*val as u128) >> (u64::BITS - input_modulus_width);
            // x < 2^64 and p < 2^64, so x * p + half cannot overflow a u128.
            let rounded = ((x * mod_p_u128 + half) >> input_modulus_width) as u64;
            // rounded <= p, so a single conditional subtraction makes it canonical.
            *val = if rounded >= mod_p {
                rounded - mod_p
            } else {
                rounded
            };
        }
    }

    /// Switches every value of `data` from the NTT prime `p` to modulus `2^w`
    /// (`w = output_modulus_width`), in place.
    ///
    /// The function computes `round(val * 2^w / p)` and stores the result in the top `w` bits
    /// of each word.
    ///
    /// `output_modulus_width` must lie in `1..=64`.
    pub(crate) fn modswitch_from_ntt_prime_to_power_of_two(
        self,
        output_modulus_width: u32,
        data: &mut [u64],
    ) {
        debug_assert!((1..=u64::BITS).contains(&output_modulus_width));
        let mod_p_u128 = self.plan.modulus() as u128;
        let half_p = mod_p_u128 >> 1;

        for val in data.iter_mut() {
            let val_u128 = u128::from(*val);
            // The rounding offset must be *added*. A bitwise OR only equals addition when
            // p / 2 < 2^w (no overlapping bits), which fails for narrow output moduli.
            // val < 2^64 and w <= 64, so the numerator fits in a u128.
            let rounded =
                (((val_u128 << output_modulus_width) + half_p) / mod_p_u128) as u64;
            // Rounding may yield exactly 2^w. The cast (w == 64) or the left shift (w < 64)
            // drops that bit, which is exactly reduction modulo 2^w.
            *val = rounded << (u64::BITS - output_modulus_width);
        }
    }

    /// Writes into `ntt` the forward NTT of `standard`, whose values are taken modulo
    /// `2^input_modulus_width` and switched to the NTT prime before the transform.
    ///
    /// `standard` is left untouched. Both polynomials must have the same length.
    pub fn forward_from_power_of_two_modulus(
        &self,
        input_modulus_width: u32,
        ntt: PolynomialMutView<'_, u64>,
        standard: PolynomialView<'_, u64>,
    ) {
        let mut ntt = ntt;
        let ntt = ntt.as_mut();
        ntt.copy_from_slice(standard.as_ref());
        self.modswitch_from_power_of_two_to_ntt_prime(input_modulus_width, ntt);
        self.plan.fwd(ntt);
    }

    /// Writes into `ntt` the forward NTT of the decomposition digits in `decomp`.
    ///
    /// Each digit is interpreted as a two's-complement `i64`. Negative digits are lifted to
    /// their non-negative representative by adding the NTT prime before the transform.
    pub fn forward_from_decomp(
        &self,
        ntt: PolynomialMutView<'_, u64>,
        decomp: PolynomialView<'_, u64>,
    ) {
        let mut ntt = ntt;
        let ntt = ntt.as_mut();
        ntt.copy_from_slice(decomp.as_ref());

        let mod_p = self.custom_modulus();
        for x in ntt.iter_mut() {
            if (*x as i64) < 0 {
                *x = x.wrapping_add(mod_p);
            }
        }
        self.plan.fwd(ntt);
    }

    /// Applies the inverse transform to `ntt` in place and switches the result to modulus
    /// `2^output_modulus_width`, stored in the top bits of each word.
    /// It then adds the result into `standard` with wrapping `u64` addition.
    ///
    /// Because values sit in the top bits, wrapping addition modulo `2^64` acts as addition
    /// modulo `2^output_modulus_width`. `ntt` is clobbered.
    pub fn add_backward_on_power_of_two_modulus(
        self,
        output_modulus_width: u32,
        standard: PolynomialMutView<'_, u64>,
        ntt: PolynomialMutView<'_, u64>,
    ) {
        let mut ntt = ntt;
        let mut standard = standard;
        let ntt = ntt.as_mut();
        let standard = standard.as_mut();
        self.plan.inv(ntt);
        self.modswitch_from_ntt_prime_to_power_of_two(output_modulus_width, ntt);
        pulp::Arch::new().dispatch(
            #[inline(always)]
            || {
                for (out, inp) in izip_eq!(standard, &*ntt) {
                    *out = u64::wrapping_add(*out, *inp);
                }
            },
        )
    }
}