use crate::core_crypto::algorithms::slice_algorithms::*;
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulusKind;
use crate::core_crypto::commons::math::decomposition::{
SignedDecomposer, SignedDecomposerNonNative,
};
use crate::core_crypto::commons::parameters::{
DecompositionBaseLog, DecompositionLevelCount, ThreadCount,
};
use crate::core_crypto::commons::traits::*;
use crate::core_crypto::entities::*;
use rayon::prelude::*;
/// Keyswitch `input_lwe_ciphertext` with `lwe_keyswitch_key`, writing the result into
/// `output_lwe_ciphertext`.
///
/// Dispatches at runtime to the power-of-two-compatible implementation or to the generic
/// custom-modulus implementation, based on the keyswitch key's ciphertext modulus.
pub fn keyswitch_lwe_ciphertext<Scalar, KSKCont, InputCont, OutputCont>(
    lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
    input_lwe_ciphertext: &LweCiphertext<InputCont>,
    output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
) where
    Scalar: UnsignedInteger,
    KSKCont: Container<Element = Scalar>,
    InputCont: Container<Element = Scalar>,
    OutputCont: ContainerMut<Element = Scalar>,
{
    // Pick the implementation matching the modulus representation of the key.
    let power_of_two_compatible = lwe_keyswitch_key
        .ciphertext_modulus()
        .is_compatible_with_native_modulus();
    if power_of_two_compatible {
        keyswitch_lwe_ciphertext_native_mod_compatible(
            lwe_keyswitch_key,
            input_lwe_ciphertext,
            output_lwe_ciphertext,
        );
    } else {
        keyswitch_lwe_ciphertext_other_mod(
            lwe_keyswitch_key,
            input_lwe_ciphertext,
            output_lwe_ciphertext,
        );
    }
}
/// Keyswitch an LWE ciphertext from the input key of `lwe_keyswitch_key` to its output key,
/// for ciphertext moduli compatible with a power of two (native or non-native power of two).
///
/// The output is overwritten: it is zeroed, its body is seeded with the input body (rounded
/// if the output modulus is a smaller non-native power of two), then for every input mask
/// element the signed decomposition terms are multiplied by the matching keyswitch-key level
/// ciphertexts and subtracted from the accumulator.
///
/// # Panics
///
/// Panics if the key and ciphertext LWE dimensions do not match, if the key and output
/// ciphertext moduli differ, or if either ciphertext modulus is not compatible with a power
/// of two.
pub fn keyswitch_lwe_ciphertext_native_mod_compatible<Scalar, KSKCont, InputCont, OutputCont>(
    lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
    input_lwe_ciphertext: &LweCiphertext<InputCont>,
    output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
) where
    Scalar: UnsignedInteger,
    KSKCont: Container<Element = Scalar>,
    InputCont: Container<Element = Scalar>,
    OutputCont: ContainerMut<Element = Scalar>,
{
    assert!(
        lwe_keyswitch_key.input_key_lwe_dimension()
            == input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
        "Mismatched input LweDimension. \
        LweKeyswitchKey input LweDimension: {:?}, input LweCiphertext LweDimension {:?}.",
        lwe_keyswitch_key.input_key_lwe_dimension(),
        input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
    );
    assert!(
        lwe_keyswitch_key.output_key_lwe_dimension()
            == output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
        "Mismatched output LweDimension. \
        LweKeyswitchKey output LweDimension: {:?}, output LweCiphertext LweDimension {:?}.",
        lwe_keyswitch_key.output_key_lwe_dimension(),
        output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
    );
    let output_ciphertext_modulus = output_lwe_ciphertext.ciphertext_modulus();
    assert_eq!(
        lwe_keyswitch_key.ciphertext_modulus(),
        output_ciphertext_modulus,
        "Mismatched CiphertextModulus. \
        LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.",
        lwe_keyswitch_key.ciphertext_modulus(),
        output_ciphertext_modulus
    );
    assert!(
        output_ciphertext_modulus.is_compatible_with_native_modulus(),
        "This operation currently only supports power of 2 moduli"
    );
    let input_ciphertext_modulus = input_lwe_ciphertext.ciphertext_modulus();
    assert!(
        input_ciphertext_modulus.is_compatible_with_native_modulus(),
        "This operation currently only supports power of 2 moduli"
    );
    // Start from an all-zero accumulator, then seed its body with the input body.
    output_lwe_ciphertext.as_mut().fill(Scalar::ZERO);
    *output_lwe_ciphertext.get_mut_body().data = *input_lwe_ciphertext.get_body().data;
    // When the output modulus differs from the input's and is a non-native power of two,
    // round the copied body to the closest value representable under the output modulus.
    // A one-level decomposer with base log equal to the modulus bit count performs exactly
    // that rounding.
    if output_ciphertext_modulus != input_ciphertext_modulus
        && !output_ciphertext_modulus.is_native_modulus()
    {
        let modulus_bits = output_ciphertext_modulus.get_custom_modulus().ilog2() as usize;
        let output_decomposer = SignedDecomposer::new(
            DecompositionBaseLog(modulus_bits),
            DecompositionLevelCount(1),
        );
        *output_lwe_ciphertext.get_mut_body().data =
            output_decomposer.closest_representable(*output_lwe_ciphertext.get_mut_body().data);
    }
    let decomposer = SignedDecomposer::new(
        lwe_keyswitch_key.decomposition_base_log(),
        lwe_keyswitch_key.decomposition_level_count(),
    );
    // For each input mask element, subtract decomposed_term * ksk_level_ciphertext from the
    // whole accumulator; this moves the mask contribution under the output key.
    for (keyswitch_key_block, &input_mask_element) in lwe_keyswitch_key
        .iter()
        .zip(input_lwe_ciphertext.get_mask().as_ref())
    {
        let decomposition_iter = decomposer.decompose(input_mask_element);
        for (level_key_ciphertext, decomposed) in keyswitch_key_block.iter().zip(decomposition_iter)
        {
            slice_wrapping_sub_scalar_mul_assign(
                output_lwe_ciphertext.as_mut(),
                level_key_ciphertext.as_ref(),
                decomposed.value(),
            );
        }
    }
}
/// Keyswitch an LWE ciphertext from the input key of `lwe_keyswitch_key` to its output key,
/// for ciphertext moduli that are NOT compatible with a power of two.
///
/// Works like [`keyswitch_lwe_ciphertext_native_mod_compatible`] but performs every operation
/// modulo the key's custom modulus, using [`SignedDecomposerNonNative`] and the
/// custom-modulus slice helpers.
///
/// # Panics
///
/// Panics on mismatched LWE dimensions, if the key, input and output moduli are not all
/// equal, or if the modulus IS compatible with a power of two (use the native-compatible
/// variant in that case).
pub fn keyswitch_lwe_ciphertext_other_mod<Scalar, KSKCont, InputCont, OutputCont>(
    lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
    input_lwe_ciphertext: &LweCiphertext<InputCont>,
    output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
) where
    Scalar: UnsignedInteger,
    KSKCont: Container<Element = Scalar>,
    InputCont: Container<Element = Scalar>,
    OutputCont: ContainerMut<Element = Scalar>,
{
    assert!(
        lwe_keyswitch_key.input_key_lwe_dimension()
            == input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
        "Mismatched input LweDimension. \
        LweKeyswitchKey input LweDimension: {:?}, input LweCiphertext LweDimension {:?}.",
        lwe_keyswitch_key.input_key_lwe_dimension(),
        input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
    );
    assert!(
        lwe_keyswitch_key.output_key_lwe_dimension()
            == output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
        "Mismatched output LweDimension. \
        LweKeyswitchKey output LweDimension: {:?}, output LweCiphertext LweDimension {:?}.",
        lwe_keyswitch_key.output_key_lwe_dimension(),
        output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
    );
    assert_eq!(
        lwe_keyswitch_key.ciphertext_modulus(),
        output_lwe_ciphertext.ciphertext_modulus(),
        "Mismatched CiphertextModulus. \
        LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.",
        lwe_keyswitch_key.ciphertext_modulus(),
        output_lwe_ciphertext.ciphertext_modulus()
    );
    assert_eq!(
        lwe_keyswitch_key.ciphertext_modulus(),
        input_lwe_ciphertext.ciphertext_modulus(),
        "Mismatched CiphertextModulus. \
        LweKeyswitchKey CiphertextModulus: {:?}, input LweCiphertext CiphertextModulus {:?}.",
        lwe_keyswitch_key.ciphertext_modulus(),
        input_lwe_ciphertext.ciphertext_modulus()
    );
    let ciphertext_modulus = lwe_keyswitch_key.ciphertext_modulus();
    assert!(
        !ciphertext_modulus.is_compatible_with_native_modulus(),
        "This operation currently only supports non power of 2 moduli"
    );
    // Zero the accumulator and seed its body with the input body. Input and output moduli
    // are asserted equal above, so no body rounding is required here.
    output_lwe_ciphertext.as_mut().fill(Scalar::ZERO);
    *output_lwe_ciphertext.get_mut_body().data = *input_lwe_ciphertext.get_body().data;
    let decomposer = SignedDecomposerNonNative::new(
        lwe_keyswitch_key.decomposition_base_log(),
        lwe_keyswitch_key.decomposition_level_count(),
        ciphertext_modulus,
    );
    // Subtract decomposed_term * ksk_level_ciphertext from the accumulator, with all
    // arithmetic carried out modulo the custom modulus.
    for (keyswitch_key_block, &input_mask_element) in lwe_keyswitch_key
        .iter()
        .zip(input_lwe_ciphertext.get_mask().as_ref())
    {
        let decomposition_iter = decomposer.decompose(input_mask_element);
        for (level_key_ciphertext, decomposed) in keyswitch_key_block.iter().zip(decomposition_iter)
        {
            slice_wrapping_sub_scalar_mul_assign_custom_modulus(
                output_lwe_ciphertext.as_mut(),
                level_key_ciphertext.as_ref(),
                decomposed.modular_value(),
                ciphertext_modulus.get_custom_modulus().cast_into(),
            );
        }
    }
}
/// Keyswitch an LWE ciphertext while narrowing its scalar type, going from a wider
/// `InputScalar` to a strictly narrower `OutputScalar` (the keyswitch key is expressed in
/// `OutputScalar`).
///
/// The input body is rounded to the output precision, shifted down by the bit-width
/// difference and cast; each input mask element is decomposed in the input domain and the
/// decomposition terms are cast to `OutputScalar` for the accumulation against the key.
///
/// # Panics
///
/// Panics if `InputScalar` is not strictly wider than `OutputScalar`, if the key's
/// `base_log * level_count` exceeds the `OutputScalar` bit count, on mismatched LWE
/// dimensions or moduli, or if a modulus is not compatible with a power of two.
pub fn keyswitch_lwe_ciphertext_with_scalar_change<
    InputScalar,
    OutputScalar,
    KSKCont,
    InputCont,
    OutputCont,
>(
    lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
    input_lwe_ciphertext: &LweCiphertext<InputCont>,
    output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
) where
    InputScalar: UnsignedInteger,
    OutputScalar: UnsignedInteger + CastFrom<InputScalar>,
    KSKCont: Container<Element = OutputScalar>,
    InputCont: Container<Element = InputScalar>,
    OutputCont: ContainerMut<Element = OutputScalar>,
{
    assert!(
        InputScalar::BITS > OutputScalar::BITS,
        "This operation only supports going from a large InputScalar type \
        to a strictly smaller OutputScalar type."
    );
    assert!(
        lwe_keyswitch_key.decomposition_base_log().0
            * lwe_keyswitch_key.decomposition_level_count().0
            <= OutputScalar::BITS,
        "This operation only supports a DecompositionBaseLog and DecompositionLevelCount product \
        smaller than the OutputScalar bit count."
    );
    assert!(
        lwe_keyswitch_key.input_key_lwe_dimension()
            == input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
        "Mismatched input LweDimension. \
        LweKeyswitchKey input LweDimension: {:?}, input LweCiphertext LweDimension {:?}.",
        lwe_keyswitch_key.input_key_lwe_dimension(),
        input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
    );
    assert!(
        lwe_keyswitch_key.output_key_lwe_dimension()
            == output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
        "Mismatched output LweDimension. \
        LweKeyswitchKey output LweDimension: {:?}, output LweCiphertext LweDimension {:?}.",
        lwe_keyswitch_key.output_key_lwe_dimension(),
        output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
    );
    let output_ciphertext_modulus = output_lwe_ciphertext.ciphertext_modulus();
    assert_eq!(
        lwe_keyswitch_key.ciphertext_modulus(),
        output_ciphertext_modulus,
        "Mismatched CiphertextModulus. \
        LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.",
        lwe_keyswitch_key.ciphertext_modulus(),
        output_ciphertext_modulus
    );
    assert!(
        output_ciphertext_modulus.is_compatible_with_native_modulus(),
        "This operation currently only supports power of 2 moduli"
    );
    let input_ciphertext_modulus = input_lwe_ciphertext.ciphertext_modulus();
    assert!(
        input_ciphertext_modulus.is_compatible_with_native_modulus(),
        "This operation currently only supports power of 2 moduli"
    );
    output_lwe_ciphertext.as_mut().fill(OutputScalar::ZERO);
    // Bit count of the output modulus: the full output word when native, the power-of-two
    // exponent otherwise. `Other` is excluded by the native-compatibility assert above.
    let output_modulus_bits = match output_ciphertext_modulus.kind() {
        CiphertextModulusKind::Native => OutputScalar::BITS,
        CiphertextModulusKind::NonNativePowerOfTwo => {
            output_ciphertext_modulus.get_custom_modulus().ilog2() as usize
        }
        CiphertextModulusKind::Other => unreachable!(),
    };
    // One-level decomposer keeping the top `output_modulus_bits` bits of the input body,
    // i.e. rounding the wide body to the precision the output modulus can represent.
    let input_body_decomposer = SignedDecomposer::new(
        DecompositionBaseLog(output_modulus_bits),
        DecompositionLevelCount(1),
    );
    // Number of low bits dropped when shrinking the input word to the output word.
    // NOTE(review): this assumes power-of-two moduli keep their data in the most significant
    // bits of the word, so a plain right shift rescales correctly — confirm against the
    // CiphertextModulus representation.
    let input_to_output_scaling_factor = InputScalar::BITS - OutputScalar::BITS;
    let rounded_downscaled_body = input_body_decomposer
        .closest_representable(*input_lwe_ciphertext.get_body().data)
        >> input_to_output_scaling_factor;
    *output_lwe_ciphertext.get_mut_body().data = rounded_downscaled_body.cast_into();
    // Decomposition runs on the wide input scalars; each term fits in OutputScalar because
    // base_log * level_count <= OutputScalar::BITS (asserted above).
    let input_decomposer = SignedDecomposer::<InputScalar>::new(
        lwe_keyswitch_key.decomposition_base_log(),
        lwe_keyswitch_key.decomposition_level_count(),
    );
    // Subtract cast(decomposed_term) * ksk_level_ciphertext from the output accumulator for
    // every input mask element.
    for (keyswitch_key_block, &input_mask_element) in lwe_keyswitch_key
        .iter()
        .zip(input_lwe_ciphertext.get_mask().as_ref())
    {
        let decomposition_iter = input_decomposer.decompose(input_mask_element);
        for (level_key_ciphertext, decomposed) in keyswitch_key_block.iter().zip(decomposition_iter)
        {
            slice_wrapping_sub_scalar_mul_assign(
                output_lwe_ciphertext.as_mut(),
                level_key_ciphertext.as_ref(),
                decomposed.value().cast_into(),
            );
        }
    }
}
/// Parallel keyswitch of `input_lwe_ciphertext` into `output_lwe_ciphertext` using
/// `lwe_keyswitch_key`, sizing the work split from the current rayon thread pool.
pub fn par_keyswitch_lwe_ciphertext<Scalar, KSKCont, InputCont, OutputCont>(
    lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
    input_lwe_ciphertext: &LweCiphertext<InputCont>,
    output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
) where
    Scalar: UnsignedInteger + Send + Sync,
    KSKCont: Container<Element = Scalar>,
    InputCont: Container<Element = Scalar>,
    OutputCont: ContainerMut<Element = Scalar>,
{
    // Delegate to the explicit-thread-count variant with rayon's current pool size.
    par_keyswitch_lwe_ciphertext_with_thread_count(
        lwe_keyswitch_key,
        input_lwe_ciphertext,
        output_lwe_ciphertext,
        ThreadCount(rayon::current_num_threads()),
    );
}
/// Parallel keyswitch with an explicit `thread_count`, dispatching to the
/// power-of-two-compatible or custom-modulus implementation based on the key's modulus.
pub fn par_keyswitch_lwe_ciphertext_with_thread_count<Scalar, KSKCont, InputCont, OutputCont>(
    lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
    input_lwe_ciphertext: &LweCiphertext<InputCont>,
    output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
    thread_count: ThreadCount,
) where
    Scalar: UnsignedInteger + Send + Sync,
    KSKCont: Container<Element = Scalar>,
    InputCont: Container<Element = Scalar>,
    OutputCont: ContainerMut<Element = Scalar>,
{
    // Route on the modulus representation of the keyswitch key.
    if lwe_keyswitch_key
        .ciphertext_modulus()
        .is_compatible_with_native_modulus()
    {
        par_keyswitch_lwe_ciphertext_with_thread_count_native_mod_compatible(
            lwe_keyswitch_key,
            input_lwe_ciphertext,
            output_lwe_ciphertext,
            thread_count,
        );
    } else {
        par_keyswitch_lwe_ciphertext_with_thread_count_other_mod(
            lwe_keyswitch_key,
            input_lwe_ciphertext,
            output_lwe_ciphertext,
            thread_count,
        );
    }
}
/// Parallel version of [`keyswitch_lwe_ciphertext_native_mod_compatible`], running the mask
/// accumulation on up to `thread_count` rayon work items.
///
/// The keyswitch key blocks and the input mask are split into chunks; each chunk is key
/// switched into a private zeroed accumulator ciphertext, the accumulators are combined with
/// element-wise wrapping additions, and the combined mask/body are written into the output.
///
/// # Panics
///
/// Panics on mismatched LWE dimensions, mismatched key/output moduli, moduli not compatible
/// with a power of two, or `thread_count == 0`.
pub fn par_keyswitch_lwe_ciphertext_with_thread_count_native_mod_compatible<
    Scalar,
    KSKCont,
    InputCont,
    OutputCont,
>(
    lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
    input_lwe_ciphertext: &LweCiphertext<InputCont>,
    output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
    thread_count: ThreadCount,
) where
    Scalar: UnsignedInteger + Send + Sync,
    KSKCont: Container<Element = Scalar>,
    InputCont: Container<Element = Scalar>,
    OutputCont: ContainerMut<Element = Scalar>,
{
    assert!(
        lwe_keyswitch_key.input_key_lwe_dimension()
            == input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
        "Mismatched input LweDimension. \
        LweKeyswitchKey input LweDimension: {:?}, input LweCiphertext LweDimension {:?}.",
        lwe_keyswitch_key.input_key_lwe_dimension(),
        input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
    );
    assert!(
        lwe_keyswitch_key.output_key_lwe_dimension()
            == output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
        "Mismatched output LweDimension. \
        LweKeyswitchKey output LweDimension: {:?}, output LweCiphertext LweDimension {:?}.",
        lwe_keyswitch_key.output_key_lwe_dimension(),
        output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
    );
    let output_ciphertext_modulus = output_lwe_ciphertext.ciphertext_modulus();
    assert_eq!(
        lwe_keyswitch_key.ciphertext_modulus(),
        output_ciphertext_modulus,
        "Mismatched CiphertextModulus. \
        LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.",
        lwe_keyswitch_key.ciphertext_modulus(),
        output_ciphertext_modulus
    );
    assert!(
        output_ciphertext_modulus.is_compatible_with_native_modulus(),
        "This operation currently only supports power of 2 moduli"
    );
    let input_ciphertext_modulus = input_lwe_ciphertext.ciphertext_modulus();
    assert!(
        input_ciphertext_modulus.is_compatible_with_native_modulus(),
        "This operation currently only supports power of 2 moduli"
    );
    assert!(
        thread_count.0 != 0,
        "Got thread_count == 0, this is not supported"
    );
    // Zero the output and seed its body with the input body.
    output_lwe_ciphertext.as_mut().fill(Scalar::ZERO);
    let output_lwe_size = output_lwe_ciphertext.lwe_size();
    *output_lwe_ciphertext.get_mut_body().data = *input_lwe_ciphertext.get_body().data;
    // When the output modulus differs from the input's and is a non-native power of two,
    // round the copied body to the closest representable value under the output modulus
    // (one-level decomposition with base log equal to the modulus bit count).
    if output_ciphertext_modulus != input_ciphertext_modulus
        && !output_ciphertext_modulus.is_native_modulus()
    {
        let modulus_bits = output_ciphertext_modulus.get_custom_modulus().ilog2() as usize;
        let output_decomposer = SignedDecomposer::new(
            DecompositionBaseLog(modulus_bits),
            DecompositionLevelCount(1),
        );
        *output_lwe_ciphertext.get_mut_body().data =
            output_decomposer.closest_representable(*output_lwe_ciphertext.get_mut_body().data);
    }
    let decomposer = SignedDecomposer::new(
        lwe_keyswitch_key.decomposition_base_log(),
        lwe_keyswitch_key.decomposition_level_count(),
    );
    // Never create more chunks than rayon has threads available.
    let thread_count = thread_count.0.min(rayon::current_num_threads());
    let mut intermediate_accumulators = Vec::with_capacity(thread_count);
    // NOTE(review): the chunk size is derived from the full lwe_size rather than the mask
    // length (lwe_size - 1); the zip below truncates to the mask so results are unaffected,
    // but the split may end up using fewer chunks than thread_count — confirm intent.
    let chunk_size = input_lwe_ciphertext.lwe_size().0.div_ceil(thread_count);
    // Key switch each (key chunk, mask chunk) pair into a private accumulator ciphertext.
    lwe_keyswitch_key
        .par_chunks(chunk_size)
        .zip(
            input_lwe_ciphertext
                .get_mask()
                .as_ref()
                .par_chunks(chunk_size),
        )
        .map(|(keyswitch_key_block_chunk, input_mask_element_chunk)| {
            let mut buffer =
                LweCiphertext::new(Scalar::ZERO, output_lwe_size, output_ciphertext_modulus);
            for (keyswitch_key_block, &input_mask_element) in keyswitch_key_block_chunk
                .iter()
                .zip(input_mask_element_chunk.iter())
            {
                let decomposition_iter = decomposer.decompose(input_mask_element);
                for (level_key_ciphertext, decomposed) in
                    keyswitch_key_block.iter().zip(decomposition_iter)
                {
                    slice_wrapping_sub_scalar_mul_assign(
                        buffer.as_mut(),
                        level_key_ciphertext.as_ref(),
                        decomposed.value(),
                    );
                }
            }
            buffer
        })
        .collect_into_vec(&mut intermediate_accumulators);
    // Fold the per-chunk accumulators together with element-wise wrapping additions.
    let reduced = intermediate_accumulators
        .par_iter_mut()
        .reduce_with(|lhs, rhs| {
            lhs.as_mut()
                .iter_mut()
                .zip(rhs.as_ref().iter())
                .for_each(|(dst, &src)| *dst = (*dst).wrapping_add(src));
            lhs
        })
        // unwrap: at least one accumulator was collected since thread_count >= 1 and the key
        // is non-empty — NOTE(review): would panic on a zero-dimension input.
        .unwrap();
    // The combined mask becomes the output mask...
    output_lwe_ciphertext
        .get_mut_mask()
        .as_mut()
        .copy_from_slice(reduced.get_mask().as_ref());
    // ...and the combined body (keyswitch contributions) is folded into the seeded body.
    let reduced_ksed_body = *reduced.get_body().data;
    *output_lwe_ciphertext.get_mut_body().data =
        (*output_lwe_ciphertext.get_mut_body().data).wrapping_add(reduced_ksed_body);
}
/// Parallel version of [`keyswitch_lwe_ciphertext_other_mod`], running the mask accumulation
/// on up to `thread_count` rayon work items, all arithmetic modulo the key's custom modulus.
///
/// Mirrors the native-compatible parallel variant: chunked per-task accumulators, then an
/// element-wise custom-modulus reduction into the output.
///
/// # Panics
///
/// Panics on mismatched LWE dimensions, if the key, input and output moduli are not all
/// equal, if the modulus IS compatible with a power of two, or if `thread_count == 0`.
pub fn par_keyswitch_lwe_ciphertext_with_thread_count_other_mod<
    Scalar,
    KSKCont,
    InputCont,
    OutputCont,
>(
    lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
    input_lwe_ciphertext: &LweCiphertext<InputCont>,
    output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
    thread_count: ThreadCount,
) where
    Scalar: UnsignedInteger + Send + Sync,
    KSKCont: Container<Element = Scalar>,
    InputCont: Container<Element = Scalar>,
    OutputCont: ContainerMut<Element = Scalar>,
{
    assert!(
        lwe_keyswitch_key.input_key_lwe_dimension()
            == input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
        "Mismatched input LweDimension. \
        LweKeyswitchKey input LweDimension: {:?}, input LweCiphertext LweDimension {:?}.",
        lwe_keyswitch_key.input_key_lwe_dimension(),
        input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
    );
    assert!(
        lwe_keyswitch_key.output_key_lwe_dimension()
            == output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
        "Mismatched output LweDimension. \
        LweKeyswitchKey output LweDimension: {:?}, output LweCiphertext LweDimension {:?}.",
        lwe_keyswitch_key.output_key_lwe_dimension(),
        output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
    );
    assert_eq!(
        lwe_keyswitch_key.ciphertext_modulus(),
        output_lwe_ciphertext.ciphertext_modulus(),
        "Mismatched CiphertextModulus. \
        LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.",
        lwe_keyswitch_key.ciphertext_modulus(),
        output_lwe_ciphertext.ciphertext_modulus()
    );
    assert_eq!(
        lwe_keyswitch_key.ciphertext_modulus(),
        input_lwe_ciphertext.ciphertext_modulus(),
        "Mismatched CiphertextModulus. \
        LweKeyswitchKey CiphertextModulus: {:?}, input LweCiphertext CiphertextModulus {:?}.",
        lwe_keyswitch_key.ciphertext_modulus(),
        input_lwe_ciphertext.ciphertext_modulus()
    );
    let ciphertext_modulus = lwe_keyswitch_key.ciphertext_modulus();
    assert!(
        !ciphertext_modulus.is_compatible_with_native_modulus(),
        "This operation currently only supports non power of 2 moduli"
    );
    // Scalar form of the custom modulus, reused by the per-element modular operations below.
    let ciphertext_modulus_as_scalar: Scalar = ciphertext_modulus.get_custom_modulus().cast_into();
    assert!(
        thread_count.0 != 0,
        "Got thread_count == 0, this is not supported"
    );
    // Zero the output and seed its body with the input body (moduli asserted equal above).
    output_lwe_ciphertext.as_mut().fill(Scalar::ZERO);
    let output_lwe_size = output_lwe_ciphertext.lwe_size();
    *output_lwe_ciphertext.get_mut_body().data = *input_lwe_ciphertext.get_body().data;
    let decomposer = SignedDecomposerNonNative::new(
        lwe_keyswitch_key.decomposition_base_log(),
        lwe_keyswitch_key.decomposition_level_count(),
        ciphertext_modulus,
    );
    // Never create more chunks than rayon has threads available.
    let thread_count = thread_count.0.min(rayon::current_num_threads());
    let mut intermediate_accumulators = Vec::with_capacity(thread_count);
    // NOTE(review): chunk size derived from the full lwe_size rather than the mask length
    // (lwe_size - 1); the zip truncates to the mask so results are unaffected — confirm
    // intent (same remark as the native-compatible variant).
    let chunk_size = input_lwe_ciphertext.lwe_size().0.div_ceil(thread_count);
    // Key switch each (key chunk, mask chunk) pair into a private accumulator, modulo the
    // custom modulus.
    lwe_keyswitch_key
        .par_chunks(chunk_size)
        .zip(
            input_lwe_ciphertext
                .get_mask()
                .as_ref()
                .par_chunks(chunk_size),
        )
        .map(|(keyswitch_key_block_chunk, input_mask_element_chunk)| {
            let mut buffer = LweCiphertext::new(Scalar::ZERO, output_lwe_size, ciphertext_modulus);
            for (keyswitch_key_block, &input_mask_element) in keyswitch_key_block_chunk
                .iter()
                .zip(input_mask_element_chunk.iter())
            {
                let decomposition_iter = decomposer.decompose(input_mask_element);
                for (level_key_ciphertext, decomposed) in
                    keyswitch_key_block.iter().zip(decomposition_iter)
                {
                    slice_wrapping_sub_scalar_mul_assign_custom_modulus(
                        buffer.as_mut(),
                        level_key_ciphertext.as_ref(),
                        decomposed.modular_value(),
                        ciphertext_modulus_as_scalar,
                    );
                }
            }
            buffer
        })
        .collect_into_vec(&mut intermediate_accumulators);
    // Fold the per-chunk accumulators together with element-wise custom-modulus additions.
    let reduced = intermediate_accumulators
        .par_iter_mut()
        .reduce_with(|lhs, rhs| {
            lhs.as_mut()
                .iter_mut()
                .zip(rhs.as_ref().iter())
                .for_each(|(dst, &src)| {
                    *dst = (*dst).wrapping_add_custom_mod(src, ciphertext_modulus_as_scalar)
                });
            lhs
        })
        // unwrap: at least one accumulator was collected since thread_count >= 1 and the key
        // is non-empty — NOTE(review): would panic on a zero-dimension input.
        .unwrap();
    // Combined mask becomes the output mask; combined body is folded into the seeded body.
    output_lwe_ciphertext
        .get_mut_mask()
        .as_mut()
        .copy_from_slice(reduced.get_mask().as_ref());
    let reduced_ksed_body = *reduced.get_body().data;
    *output_lwe_ciphertext.get_mut_body().data = (*output_lwe_ciphertext.get_mut_body().data)
        .wrapping_add_custom_mod(reduced_ksed_body, ciphertext_modulus_as_scalar);
}
use crate::core_crypto::commons::noise_formulas::noise_simulation::traits::{
AllocateLweKeyswitchResult, LweKeyswitch,
};
use crate::core_crypto::fft_impl::fft64::math::fft::id;
use std::any::TypeId;
impl<Scalar: UnsignedInteger, KeyCont: Container<Element = Scalar>> AllocateLweKeyswitchResult
for LweKeyswitchKey<KeyCont>
{
type Output = LweCiphertextOwned<Scalar>;
type SideResources = ();
fn allocate_lwe_keyswitch_result(
&self,
_side_resources: &mut Self::SideResources,
) -> Self::Output {
Self::Output::new(
Scalar::ZERO,
self.output_lwe_size(),
self.ciphertext_modulus(),
)
}
}
impl<
        InputScalar: UnsignedInteger,
        OutputScalar: UnsignedInteger + CastFrom<InputScalar>,
        KeyCont: Container<Element = OutputScalar>,
        InputCont: Container<Element = InputScalar>,
        OutputCont: ContainerMut<Element = OutputScalar>,
    > LweKeyswitch<LweCiphertext<InputCont>, LweCiphertext<OutputCont>>
    for LweKeyswitchKey<KeyCont>
{
    type SideResources = ();

    /// Keyswitch `input` into `output`, selecting at runtime between the same-scalar path and
    /// the scalar-narrowing path.
    fn lwe_keyswitch(
        &self,
        input: &LweCiphertext<InputCont>,
        output: &mut LweCiphertext<OutputCont>,
        _side_resources: &mut Self::SideResources,
    ) {
        // `keyswitch_lwe_ciphertext_with_scalar_change` asserts InputScalar is strictly wider
        // than OutputScalar, so when the two generic scalars are the same concrete type we
        // must take the same-scalar path instead.
        if TypeId::of::<InputScalar>() == TypeId::of::<OutputScalar>() {
            let input_content = input.as_ref();
            // `id` presumably reinterprets the container between the two (identical) element
            // types, sound only because the TypeId check above guarantees they match —
            // NOTE(review): confirm `id`'s contract. `try_to().unwrap()` converts the modulus
            // between the two scalar parameterizations.
            let input_as_output_scalar = LweCiphertext::from_container(
                id(input_content),
                input.ciphertext_modulus().try_to().unwrap(),
            );
            keyswitch_lwe_ciphertext(self, &input_as_output_scalar, output);
        } else {
            keyswitch_lwe_ciphertext_with_scalar_change(self, input, output);
        }
    }
}