use super::atomic_pattern::AtomicPatternServerKey;
use super::backward_compatibility::key_switching_key::{
CompressedKeySwitchingKeyMaterialVersions, CompressedKeySwitchingKeyVersions,
KeySwitchingKeyDestinationAtomicPatternVersions, KeySwitchingKeyMaterialVersions,
KeySwitchingKeyVersions,
};
use super::server_key::{
KS32ServerKeyView, ServerKeyView, ShortintBootstrappingKey, StandardServerKeyView,
};
use super::AtomicPatternKind;
use crate::conformance::ParameterSetConformant;
use crate::core_crypto::prelude::{
keyswitch_lwe_ciphertext, CastFrom, CastInto, Cleartext, LweCiphertext, LweCiphertextOwned,
LweKeyswitchKeyConformanceParams, LweKeyswitchKeyOwned, SeededLweKeyswitchKeyOwned,
UnsignedInteger, UnsignedTorus,
};
use crate::shortint::atomic_pattern::AtomicPattern;
use crate::shortint::ciphertext::{unchecked_create_trivial_with_lwe_size, Degree};
use crate::shortint::client_key::atomic_pattern::EncryptionAtomicPattern;
use crate::shortint::client_key::secret_encryption_key::SecretEncryptionKeyView;
use crate::shortint::engine::ShortintEngine;
use crate::shortint::parameters::{
EncryptionKeyChoice, NoiseLevel, ShortintKeySwitchingParameters,
};
use crate::shortint::server_key::apply_programmable_bootstrap;
use crate::shortint::{Ciphertext, ClientKey, CompressedServerKey, MaxNoiseLevel, ServerKey};
use core::cmp::Ordering;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use tfhe_versionable::Versionize;
#[cfg(test)]
mod test;
/// Atomic pattern of the destination [`ServerKey`] targeted by a [`KeySwitchingKey`].
///
/// Stored alongside the keyswitching material so it can be compared during conformance checks.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Versionize)]
#[versionize(KeySwitchingKeyDestinationAtomicPatternVersions)]
pub enum KeySwitchingKeyDestinationAtomicPattern {
    /// The destination server key uses the standard atomic pattern.
    Standard,
    /// The destination server key uses the KS32 atomic pattern.
    KeySwitch32,
}
impl From<AtomicPatternKind> for KeySwitchingKeyDestinationAtomicPattern {
    /// Maps an [`AtomicPatternKind`] to the destination pattern it corresponds to.
    ///
    /// The payload carried by [`AtomicPatternKind::Standard`] is not needed here and is ignored.
    fn from(kind: AtomicPatternKind) -> Self {
        match kind {
            AtomicPatternKind::KeySwitch32 => Self::KeySwitch32,
            AtomicPatternKind::Standard(_) => Self::Standard,
        }
    }
}
/// Key material needed to cast shortint ciphertexts from a source key to a destination key.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Versionize)]
#[versionize(KeySwitchingKeyMaterialVersions)]
pub struct KeySwitchingKeyMaterial {
    /// LWE keyswitching key from the input secret key to the selected destination key.
    pub(crate) key_switching_key: LweKeyswitchKeyOwned<u64>,
    /// As computed by the builders: `log2(output full modulus) - log2(input full modulus)`,
    /// where "full modulus" is `carry_modulus * message_modulus`. Negative when casting to a
    /// smaller message space.
    pub(crate) cast_rshift: i8,
    /// Destination key (big or small) under which the keyswitch output is encrypted.
    pub(crate) destination_key: EncryptionKeyChoice,
    /// Atomic pattern of the destination server key.
    pub(crate) destination_atomic_pattern: KeySwitchingKeyDestinationAtomicPattern,
}
impl KeySwitchingKeyMaterial {
    /// Consumes the material and returns its components as
    /// `(key_switching_key, cast_rshift, destination_key, destination_atomic_pattern)`.
    pub fn into_raw_parts(
        self,
    ) -> (
        LweKeyswitchKeyOwned<u64>,
        i8,
        EncryptionKeyChoice,
        KeySwitchingKeyDestinationAtomicPattern,
    ) {
        (
            self.key_switching_key,
            self.cast_rshift,
            self.destination_key,
            self.destination_atomic_pattern,
        )
    }

    /// Assembles material from its components without any validation.
    ///
    /// Consistency with server keys is checked by the `from_raw_parts` constructors of
    /// [`KeySwitchingKey`] and [`KeySwitchingKeyView`].
    pub fn from_raw_parts(
        key_switching_key: LweKeyswitchKeyOwned<u64>,
        cast_rshift: i8,
        destination_key: EncryptionKeyChoice,
        destination_atomic_pattern: KeySwitchingKeyDestinationAtomicPattern,
    ) -> Self {
        Self {
            destination_atomic_pattern,
            destination_key,
            cast_rshift,
            key_switching_key,
        }
    }

    /// Borrows the material as a [`KeySwitchingKeyMaterialView`]. The scalar fields are copied.
    pub fn as_view(&self) -> KeySwitchingKeyMaterialView<'_> {
        let Self {
            key_switching_key,
            cast_rshift,
            destination_key,
            destination_atomic_pattern,
        } = self;
        KeySwitchingKeyMaterialView {
            key_switching_key,
            cast_rshift: *cast_rshift,
            destination_key: *destination_key,
            destination_atomic_pattern: *destination_atomic_pattern,
        }
    }
}
/// Intermediate produced while generating a [`KeySwitchingKey`].
///
/// Owns only the freshly generated material and borrows the server keys. It is converted into
/// an owned [`KeySwitchingKey`] via `From`.
pub(crate) struct KeySwitchingKeyBuildHelper<'keys> {
    /// Newly generated keyswitching material.
    pub(crate) key_switching_key_material: KeySwitchingKeyMaterial,
    /// View on the destination server key.
    pub(crate) dest_server_key: ServerKeyView<'keys>,
    /// Source server key. Required when casting to a smaller full message modulus.
    pub(crate) src_server_key: Option<&'keys ServerKey>,
}
/// Owned key used to cast shortint ciphertexts from a source key to destination parameters.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Versionize)]
#[versionize(KeySwitchingKeyVersions)]
pub struct KeySwitchingKey {
    /// Keyswitching key plus cast metadata.
    pub(crate) key_switching_key_material: KeySwitchingKeyMaterial,
    /// Server key of the destination parameters, used for the PBS applied after the keyswitch.
    pub(crate) dest_server_key: ServerKey,
    /// Source server key. Required when the cast reduces the full message modulus.
    pub(crate) src_server_key: Option<ServerKey>,
}
impl From<KeySwitchingKeyBuildHelper<'_>> for KeySwitchingKey {
fn from(value: KeySwitchingKeyBuildHelper) -> Self {
let KeySwitchingKeyBuildHelper {
key_switching_key_material,
dest_server_key,
src_server_key,
} = value;
Self {
key_switching_key_material,
dest_server_key: dest_server_key.owned(),
src_server_key: src_server_key.map(ToOwned::to_owned),
}
}
}
/// Classification of a freshly keyswitched ciphertext relative to the destination server key.
enum CastCiphertext<Scalar: UnsignedInteger> {
    /// The ciphertext is already encrypted under the key the destination server key expects.
    CorrectKey(Ciphertext),
    /// Produced when the keyswitch output is under the small key while the destination expects
    /// the big key. A PBS with the destination bootstrapping key is needed to fix this.
    WrongKeyRequiresPBS {
        /// Raw LWE ciphertext in the scalar type of the destination bootstrapping key.
        ct: LweCiphertextOwned<Scalar>,
        /// Degree of the ciphertext right after the keyswitch.
        degree: Degree,
    },
}
impl CastCiphertext<u64> {
    /// Classifies `keyswitched` for a destination server key using the standard atomic pattern.
    ///
    /// `keyswitch_destination_key` is the key the cast keyswitch produced. It is compared with
    /// the key the destination server key expects, derived from its PBS order:
    /// - same key: the ciphertext is returned as is;
    /// - small produced, big expected: a PBS is required afterwards;
    /// - big produced, small expected: the destination's own keyswitching key moves the
    ///   ciphertext to the small key. Degree and noise level are carried over, with the max
    ///   noise level left unknown.
    fn get_cast_type_standard(
        keyswitched: Ciphertext,
        dest_server_key: StandardServerKeyView<'_>,
        keyswitch_destination_key: EncryptionKeyChoice,
    ) -> Self {
        let expected_key =
            EncryptionKeyChoice::from(dest_server_key.atomic_pattern.kind().pbs_order());

        match (keyswitch_destination_key, expected_key) {
            (EncryptionKeyChoice::Big, EncryptionKeyChoice::Big)
            | (EncryptionKeyChoice::Small, EncryptionKeyChoice::Small) => {
                Self::CorrectKey(keyswitched)
            }
            (EncryptionKeyChoice::Small, EncryptionKeyChoice::Big) => Self::WrongKeyRequiresPBS {
                degree: keyswitched.degree,
                ct: keyswitched.ct,
            },
            (EncryptionKeyChoice::Big, EncryptionKeyChoice::Small) => {
                let mut on_small_key = dest_server_key.create_trivial(0);
                on_small_key.degree = keyswitched.degree;
                on_small_key.set_noise_level(keyswitched.noise_level(), MaxNoiseLevel::UNKNOWN);
                keyswitch_lwe_ciphertext(
                    &dest_server_key.atomic_pattern.key_switching_key,
                    &keyswitched.ct,
                    &mut on_small_key.ct,
                );
                Self::CorrectKey(on_small_key)
            }
        }
    }
}
impl CastCiphertext<u32> {
    /// Classifies `keyswitched` for a destination server key using the KS32 atomic pattern.
    ///
    /// A ciphertext keyswitched to the big key is returned as is. One keyswitched to the small
    /// key is converted to the `u32` form consumed by the KS32 bootstrapping key.
    ///
    /// # Panics
    ///
    /// - If the destination server key's PBS order maps to the small key.
    /// - If the keyswitched ciphertext modulus cannot be represented as a `u32` modulus.
    fn get_cast_type_ks32(
        keyswitched: Ciphertext,
        dest_server_key: KS32ServerKeyView<'_>,
        keyswitch_destination_key: EncryptionKeyChoice,
    ) -> Self {
        let expected_key =
            EncryptionKeyChoice::from(dest_server_key.atomic_pattern.kind().pbs_order());

        match (keyswitch_destination_key, expected_key) {
            (EncryptionKeyChoice::Big, EncryptionKeyChoice::Big) => Self::CorrectKey(keyswitched),
            (EncryptionKeyChoice::Small, EncryptionKeyChoice::Big) => {
                Self::small_key_ct_to_ks32_pbs_input(keyswitched)
            }
            (EncryptionKeyChoice::Big | EncryptionKeyChoice::Small, EncryptionKeyChoice::Small) => {
                panic!("KS32 atomic pattern only supports encryption under the big key")
            }
        }
    }

    /// Converts a `u64` ciphertext keyswitched under the small key into a `u32` LWE ciphertext
    /// by keeping the 32 most significant bits of every coefficient.
    fn small_key_ct_to_ks32_pbs_input(keyswitched: Ciphertext) -> Self {
        let Ok(u32_modulus) = keyswitched.ct.ciphertext_modulus().try_to() else {
            panic!("Ciphertext modulus after keyswitch must be <= 2**32 for the KS32 atomic pattern")
        };

        let dropped_bits = u64::BITS - u32::BITS;
        let truncated: Vec<u32> = keyswitched
            .ct
            .as_ref()
            .iter()
            .map(|&coef| (coef >> dropped_bits) as u32)
            .collect();

        Self::WrongKeyRequiresPBS {
            ct: LweCiphertext::from_container(truncated, u32_modulus),
            degree: keyswitched.degree,
        }
    }
}
/// Borrowed counterpart of [`KeySwitchingKeyMaterial`]; scalar fields are stored by value.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct KeySwitchingKeyMaterialView<'key> {
    /// Borrowed LWE keyswitching key.
    pub(crate) key_switching_key: &'key LweKeyswitchKeyOwned<u64>,
    /// See [`KeySwitchingKeyMaterial`]; negative when casting to a smaller message space.
    pub(crate) cast_rshift: i8,
    /// Destination key (big or small) under which the keyswitch output is encrypted.
    pub(crate) destination_key: EncryptionKeyChoice,
    /// Atomic pattern of the destination server key.
    pub(crate) destination_atomic_pattern: KeySwitchingKeyDestinationAtomicPattern,
}
/// Borrowed counterpart of [`KeySwitchingKey`], used to perform casts.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct KeySwitchingKeyView<'keys> {
    /// Borrowed keyswitching material.
    pub(crate) key_switching_key_material: KeySwitchingKeyMaterialView<'keys>,
    /// View on the destination server key.
    pub(crate) dest_server_key: ServerKeyView<'keys>,
    /// Source server key. Required when casting to a smaller full message modulus.
    pub(crate) src_server_key: Option<&'keys ServerKey>,
}
impl<'keys> KeySwitchingKeyBuildHelper<'keys> {
    /// Generates keyswitching material using the thread-local [`ShortintEngine`].
    ///
    /// See [`Self::new_with_engine`] for the arguments and panics.
    pub(crate) fn new<'input_key, InputEncryptionKey>(
        input_key_pair: (InputEncryptionKey, Option<&'keys ServerKey>),
        output_key_pair: (&'keys ClientKey, &'keys ServerKey),
        params: ShortintKeySwitchingParameters,
    ) -> Self
    where
        InputEncryptionKey: Into<SecretEncryptionKeyView<'input_key>>,
    {
        ShortintEngine::with_thread_local_mut(|engine| {
            Self::new_with_engine(input_key_pair, output_key_pair, params, engine)
        })
    }

    /// Generates keyswitching material from the input secret key to the destination client
    /// key, using `engine` for randomness.
    ///
    /// - `input_key_pair`: source secret key and optional source [`ServerKey`].
    /// - `output_key_pair`: destination client key (keyswitching key target) and destination
    ///   server key.
    ///
    /// The parameter checks run before the keyswitching key is generated, so invalid inputs fail
    /// without paying for key generation.
    ///
    /// # Panics
    ///
    /// - If either full message modulus (`carry_modulus * message_modulus`) is not a power of 2.
    /// - If the input full message modulus is larger than the output one and no source
    ///   [`ServerKey`] is provided.
    pub(crate) fn new_with_engine<'input_key, InputEncryptionKey>(
        input_key_pair: (InputEncryptionKey, Option<&'keys ServerKey>),
        output_key_pair: (&'keys ClientKey, &'keys ServerKey),
        params: ShortintKeySwitchingParameters,
        engine: &mut ShortintEngine,
    ) -> Self
    where
        InputEncryptionKey: Into<SecretEncryptionKeyView<'input_key>>,
    {
        let input_secret_key: SecretEncryptionKeyView<'_> = input_key_pair.0.into();
        let (output_cks, output_sks) = output_key_pair;

        let full_message_modulus_input =
            input_secret_key.carry_modulus.0 * input_secret_key.message_modulus.0;
        let full_message_modulus_output =
            output_cks.parameters().carry_modulus().0 * output_cks.parameters().message_modulus().0;

        // Validate the parameters before the (expensive) keyswitching key generation.
        assert!(
            full_message_modulus_input.is_power_of_two()
                && full_message_modulus_output.is_power_of_two(),
            "Cannot create casting key if the full messages moduli are not a power of 2"
        );
        if full_message_modulus_input > full_message_modulus_output {
            // Downcasting needs the source server key to pre-shift the message before the
            // keyswitch (see `KeySwitchingKeyView::cast_and_apply_functions`).
            assert!(
                input_key_pair.1.is_some(),
                "Trying to build a shortint::KeySwitchingKey \
                going from a large modulus {full_message_modulus_input} \
                to a smaller modulus {full_message_modulus_output} \
                without providing a source ServerKey, this is not supported"
            );
        }

        let key_switching_key = output_cks.atomic_pattern.new_keyswitching_key_with_engine(
            &input_secret_key,
            params,
            engine,
        );

        let dest_server_key = output_sks.as_view();

        // Both moduli are powers of two, so ilog2 is their exact bit count (< 64, fits an i8).
        let nb_bits_input: i8 = full_message_modulus_input.ilog2().try_into().unwrap();
        let nb_bits_output: i8 = full_message_modulus_output.ilog2().try_into().unwrap();

        Self {
            key_switching_key_material: KeySwitchingKeyMaterial {
                key_switching_key,
                cast_rshift: nb_bits_output - nb_bits_input,
                destination_key: params.destination_key,
                destination_atomic_pattern: dest_server_key.atomic_pattern.kind().into(),
            },
            dest_server_key,
            src_server_key: input_key_pair.1,
        }
    }

    /// Borrows the helper as a [`KeySwitchingKeyView`] without cloning any server key.
    #[cfg(test)]
    pub(crate) fn as_key_switching_key_view(&self) -> KeySwitchingKeyView<'_> {
        let Self {
            key_switching_key_material,
            dest_server_key,
            src_server_key,
        } = self;
        KeySwitchingKeyView {
            key_switching_key_material: key_switching_key_material.as_view(),
            dest_server_key: *dest_server_key,
            src_server_key: *src_server_key,
        }
    }
}
impl KeySwitchingKey {
pub fn new<'input_key, InputEncryptionKey>(
input_key_pair: (InputEncryptionKey, Option<&ServerKey>),
output_key_pair: (&ClientKey, &ServerKey),
params: ShortintKeySwitchingParameters,
) -> Self
where
InputEncryptionKey: Into<SecretEncryptionKeyView<'input_key>>,
{
KeySwitchingKeyBuildHelper::new(input_key_pair, output_key_pair, params).into()
}
pub fn as_view(&self) -> KeySwitchingKeyView<'_> {
let Self {
key_switching_key_material,
dest_server_key,
src_server_key,
} = self;
KeySwitchingKeyView {
key_switching_key_material: key_switching_key_material.as_view(),
dest_server_key: dest_server_key.as_view(),
src_server_key: src_server_key.as_ref(),
}
}
pub fn into_raw_parts(self) -> (KeySwitchingKeyMaterial, ServerKey, Option<ServerKey>) {
let Self {
key_switching_key_material,
dest_server_key,
src_server_key,
} = self;
(key_switching_key_material, dest_server_key, src_server_key)
}
pub fn from_raw_parts(
key_switching_key_material: KeySwitchingKeyMaterial,
dest_server_key: ServerKey,
src_server_key: Option<ServerKey>,
) -> Self {
match src_server_key {
Some(ref src_server_key) => {
let src_lwe_dimension = src_server_key.ciphertext_lwe_dimension();
assert_eq!(
src_lwe_dimension,
key_switching_key_material
.key_switching_key
.input_key_lwe_dimension(),
"Mismatch between the source ServerKey ciphertext LweDimension ({:?}) \
and the LweKeyswitchKey input LweDimension ({:?})",
src_lwe_dimension,
key_switching_key_material
.key_switching_key
.input_key_lwe_dimension(),
);
assert_eq!(
src_server_key.ciphertext_modulus, dest_server_key.ciphertext_modulus,
"Mismatch between the source ServerKey CiphertextModulus ({:?}) \
and the destination ServerKey CiphertextModulus ({:?})",
src_server_key.ciphertext_modulus, dest_server_key.ciphertext_modulus,
);
}
None => assert!(
key_switching_key_material.cast_rshift >= 0,
"Trying to build a shortint::KeySwitchingKey with a negative cast_rshift \
without providing a source ServerKey, this is not supported"
),
}
let dst_lwe_dimension = dest_server_key
.atomic_pattern
.ciphertext_lwe_dimension_for_key(key_switching_key_material.destination_key);
assert_eq!(
dst_lwe_dimension,
key_switching_key_material
.key_switching_key
.output_key_lwe_dimension(),
"Mismatch between the destination ServerKey ciphertext LweDimension ({:?}) \
and the LweKeyswitchKey output LweDimension ({:?})",
dst_lwe_dimension,
key_switching_key_material
.key_switching_key
.output_key_lwe_dimension(),
);
assert_eq!(
key_switching_key_material
.key_switching_key
.ciphertext_modulus(),
dest_server_key.ciphertext_modulus,
"Mismatch between the LweKeyswitchKey CiphertextModulus ({:?}) \
and the destination ServerKey CiphertextModulus ({:?})",
key_switching_key_material
.key_switching_key
.ciphertext_modulus(),
dest_server_key.ciphertext_modulus,
);
Self {
key_switching_key_material,
dest_server_key,
src_server_key,
}
}
pub fn cast(&self, input_ct: &Ciphertext) -> Ciphertext {
self.as_view().cast(input_ct)
}
}
impl<'keys> KeySwitchingKeyView<'keys> {
    /// Splits the view into `(material, destination server key view, optional source key)`.
    pub fn into_raw_parts(
        self,
    ) -> (
        KeySwitchingKeyMaterialView<'keys>,
        ServerKeyView<'keys>,
        Option<&'keys ServerKey>,
    ) {
        let Self {
            key_switching_key_material,
            dest_server_key,
            src_server_key,
        } = self;
        (key_switching_key_material, dest_server_key, src_server_key)
    }

    /// Builds a view from borrowed parts, checking that they are mutually consistent.
    ///
    /// # Panics
    ///
    /// - If `src_server_key` is given and its ciphertext LWE dimension differs from the
    ///   keyswitching key input dimension, or its ciphertext modulus differs from the
    ///   destination server key's.
    /// - If `src_server_key` is `None` while `cast_rshift` is negative.
    /// - If the keyswitching key output LWE dimension or ciphertext modulus does not match the
    ///   destination server key for the selected `destination_key`.
    pub fn from_raw_parts(
        key_switching_key_material: KeySwitchingKeyMaterialView<'keys>,
        dest_server_key: &'keys ServerKey,
        src_server_key: Option<&'keys ServerKey>,
    ) -> Self {
        let dest_server_key = dest_server_key.as_view();
        match src_server_key {
            // Source ciphertexts must be valid inputs of the keyswitching key.
            Some(src_server_key) => {
                let src_lwe_dimension = src_server_key.ciphertext_lwe_dimension();
                assert_eq!(
                    src_lwe_dimension,
                    key_switching_key_material
                        .key_switching_key
                        .input_key_lwe_dimension(),
                    "Mismatch between the source ServerKey ciphertext LweDimension ({:?}) \
                    and the LweKeyswitchKey input LweDimension ({:?})",
                    src_lwe_dimension,
                    key_switching_key_material
                        .key_switching_key
                        .input_key_lwe_dimension(),
                );
                assert_eq!(
                    src_server_key.ciphertext_modulus, dest_server_key.ciphertext_modulus,
                    "Mismatch between the source ServerKey CiphertextModulus ({:?}) \
                    and the destination ServerKey CiphertextModulus ({:?})",
                    src_server_key.ciphertext_modulus, dest_server_key.ciphertext_modulus,
                );
            }
            // Without a source key the input cannot be pre-processed, so a negative shift
            // (casting to a smaller message space) is rejected.
            None => assert!(
                key_switching_key_material.cast_rshift >= 0,
                "Trying to build a shortint::KeySwitchingKey with a negative cast_rshift \
                without providing a source ServerKey, this is not supported"
            ),
        }
        // The keyswitch output must match the ciphertexts of the destination key it targets.
        let dst_lwe_dimension = dest_server_key
            .atomic_pattern
            .ciphertext_lwe_dimension_for_key(key_switching_key_material.destination_key);
        assert_eq!(
            dst_lwe_dimension,
            key_switching_key_material
                .key_switching_key
                .output_key_lwe_dimension(),
            "Mismatch between the destination ServerKey ciphertext LweDimension ({:?}) \
            and the LweKeyswitchKey output LweDimension ({:?})",
            dst_lwe_dimension,
            key_switching_key_material
                .key_switching_key
                .output_key_lwe_dimension(),
        );
        assert_eq!(
            key_switching_key_material
                .key_switching_key
                .ciphertext_modulus(),
            dest_server_key
                .atomic_pattern
                .ciphertext_modulus_for_key(key_switching_key_material.destination_key),
            "Mismatch between the LweKeyswitchKey CiphertextModulus ({:?}) \
            and the destination ServerKey CiphertextModulus ({:?})",
            key_switching_key_material
                .key_switching_key
                .ciphertext_modulus(),
            dest_server_key
                .atomic_pattern
                .ciphertext_modulus_for_key(key_switching_key_material.destination_key),
        );
        Self {
            key_switching_key_material,
            dest_server_key,
            src_server_key,
        }
    }

    /// Casts `input_ct` to the destination parameters.
    ///
    /// Equivalent to [`Self::cast_and_apply_functions`] with `None`, returning its single output.
    pub fn cast(&self, input_ct: &Ciphertext) -> Ciphertext {
        let res = self.cast_and_apply_functions(input_ct, None);
        assert_eq!(res.len(), 1);
        res.into_iter().next().unwrap()
    }

    /// Casts `input_ct` to the destination parameters and optionally applies functions to it.
    ///
    /// Steps:
    /// 1. If `cast_rshift` is negative, the input is first bootstrapped with the source server
    ///    key. The message is shifted left by `-cast_rshift` bits, modulo the input full message
    ///    modulus.
    /// 2. The (pre-processed) ciphertext is keyswitched to the selected destination key.
    /// 3. A lookup table / PBS with the destination server key evaluates each function of
    ///    `functions`, or the identity when `functions` is `None`.
    ///
    /// Returns one ciphertext per function (a single one when `functions` is `None`).
    ///
    /// # Panics
    ///
    /// - If `cast_rshift` is negative and the view has no source server key.
    /// - If the destination server key uses the dynamic atomic pattern.
    pub fn cast_and_apply_functions(
        &self,
        input_ct: &Ciphertext,
        functions: Option<&[&(dyn Fn(u64) -> u64 + Sync)]>,
    ) -> Vec<Ciphertext> {
        // Keyswitch output buffer, sized for the LWE dimension and modulus of the destination
        // key selected in the material.
        let output_lwe_size = self
            .dest_server_key
            .atomic_pattern
            .ciphertext_lwe_dimension_for_key(self.key_switching_key_material.destination_key)
            .to_lwe_size();
        let output_ciphertext_modulus = self
            .dest_server_key
            .atomic_pattern
            .ciphertext_modulus_for_key(self.key_switching_key_material.destination_key);
        let mut keyswitched = unchecked_create_trivial_with_lwe_size(
            Cleartext(0),
            output_lwe_size,
            self.dest_server_key.message_modulus,
            self.dest_server_key.carry_modulus,
            self.dest_server_key.atomic_pattern.kind(),
            output_ciphertext_modulus,
        );
        // The buffer is about to hold a real keyswitch result, not a trivial ciphertext.
        keyswitched.set_noise_level(NoiseLevel::UNKNOWN, MaxNoiseLevel::UNKNOWN);
        let cast_rshift = self.key_switching_key_material.cast_rshift;
        // Owns the pre-processed ciphertext (if any) so `pre_processed` can borrow it.
        let tmp_preprocessed: Ciphertext;
        let pre_processed = match cast_rshift.cmp(&0) {
            // Fewer message bits at the destination: re-encode the input with the source key
            // first, shifting the message left by `-cast_rshift` bits.
            Ordering::Less => {
                let src_server_key = self.src_server_key.as_ref().expect(
                    "No source server key in shortint::KeySwitchingKey \
                    which is required when casting to a smaller message modulus",
                );
                let acc = src_server_key.generate_lookup_table(|n| {
                    (n << -cast_rshift) % (input_ct.carry_modulus.0 * input_ct.message_modulus.0)
                });
                tmp_preprocessed = src_server_key.apply_lookup_table(input_ct, &acc);
                &tmp_preprocessed
            }
            Ordering::Equal | Ordering::Greater => input_ct,
        };
        keyswitch_lwe_ciphertext(
            self.key_switching_key_material.key_switching_key,
            &pre_processed.ct,
            &mut keyswitched.ct,
        );
        keyswitched.degree = pre_processed.degree;
        // Dispatch on the destination atomic pattern: each classifies the keyswitched ciphertext
        // and provides the bootstrapping key used for the final PBS.
        match self.dest_server_key.atomic_pattern {
            AtomicPatternServerKey::Standard(std_ap) => {
                let std_key = StandardServerKeyView::try_from(self.dest_server_key).unwrap();
                let cast_type = CastCiphertext::get_cast_type_standard(
                    keyswitched,
                    std_key,
                    self.key_switching_key_material.destination_key,
                );
                self.apply_cast_pbs_after_keyswitch(
                    cast_rshift,
                    cast_type,
                    functions,
                    &std_ap.bootstrapping_key,
                )
            }
            AtomicPatternServerKey::KeySwitch32(ks32_ap) => {
                let ks32_key = KS32ServerKeyView::try_from(self.dest_server_key).unwrap();
                let cast_type = CastCiphertext::get_cast_type_ks32(
                    keyswitched,
                    ks32_key,
                    self.key_switching_key_material.destination_key,
                );
                self.apply_cast_pbs_after_keyswitch(
                    cast_rshift,
                    cast_type,
                    functions,
                    &ks32_ap.bootstrapping_key,
                )
            }
            AtomicPatternServerKey::Dynamic(_) => {
                panic!("Dynamic atomic pattern does not support key switching")
            }
        }
    }

    /// Applies the final lookup table(s) to a keyswitched ciphertext and returns one output per
    /// function (the identity when `functions` is `None`). Outputs are computed in parallel.
    ///
    /// `CorrectKey` inputs go through the destination server key's `apply_lookup_table`.
    /// `WrongKeyRequiresPBS` inputs are bootstrapped directly with `compute_bsk`, and their noise
    /// level is then set to nominal.
    ///
    /// Per `cast_rshift` sign:
    /// - `Equal`: the LUT is the function itself; with the identity the input degree is kept.
    /// - `Greater`: the LUT shifts the value right by `cast_rshift` before applying the
    ///   function; the output degree is the LUT degree.
    /// - `Less`: the input was already pre-shifted by the source key, so the LUT is the function
    ///   itself; with the identity, the degree is the input degree shifted right by
    ///   `-cast_rshift`.
    fn apply_cast_pbs_after_keyswitch<KeySwitchedScalar>(
        &self,
        cast_rshift: i8,
        ct_to_cast: CastCiphertext<KeySwitchedScalar>,
        functions: Option<&[&(dyn Fn(u64) -> u64 + Sync)]>,
        compute_bsk: &ShortintBootstrappingKey<KeySwitchedScalar>,
    ) -> Vec<Ciphertext>
    where
        KeySwitchedScalar: UnsignedTorus + CastInto<usize> + CastFrom<usize>,
    {
        let output_ciphertext_count = functions.map_or_else(|| 1, |x| x.len());
        // Without user functions a single identity LUT is used.
        let identity_fn_array: &[&(dyn Fn(u64) -> u64 + Sync)] = &[&|x: u64| x];
        let functions_to_use = functions.unwrap_or(identity_fn_array);
        let using_user_provided_functions = functions.is_some();
        let using_identity_lut = !using_user_provided_functions;
        // Placeholders with the destination ciphertext layout; each one is overwritten below.
        let mut output_cts = vec![self.dest_server_key.create_trivial(0); output_ciphertext_count];
        match cast_rshift.cmp(&0) {
            Ordering::Equal => {
                match ct_to_cast {
                    CastCiphertext::CorrectKey(ciphertext) => {
                        output_cts
                            .par_iter_mut()
                            .zip(functions_to_use.par_iter())
                            .for_each(|(correct_key_ct, function)| {
                                let acc = self.dest_server_key.generate_lookup_table(function);
                                *correct_key_ct =
                                    self.dest_server_key.apply_lookup_table(&ciphertext, &acc);
                                // The identity cannot increase the value: keep the tighter
                                // input degree rather than the LUT's.
                                if using_identity_lut {
                                    correct_key_ct.degree = ciphertext.degree;
                                }
                            });
                    }
                    CastCiphertext::WrongKeyRequiresPBS {
                        ct: wrong_key_ct,
                        degree: degree_after_keyswitch,
                    } => {
                        output_cts
                            .par_iter_mut()
                            .zip(functions_to_use.par_iter())
                            .for_each(|(correct_key_ct, function)| {
                                ShortintEngine::with_thread_local_mut(|engine| {
                                    let buffers = engine.get_computation_buffers();
                                    let acc = self.dest_server_key.generate_lookup_table(function);
                                    apply_programmable_bootstrap(
                                        compute_bsk,
                                        &wrong_key_ct,
                                        &mut correct_key_ct.ct,
                                        &acc.acc,
                                        buffers,
                                    );
                                    if using_user_provided_functions {
                                        correct_key_ct.degree = acc.degree;
                                    } else {
                                        correct_key_ct.degree = degree_after_keyswitch;
                                    }
                                    correct_key_ct.set_noise_level_to_nominal();
                                });
                            });
                    }
                }
            }
            Ordering::Greater => {
                match ct_to_cast {
                    CastCiphertext::CorrectKey(ciphertext) => {
                        output_cts
                            .par_iter_mut()
                            .zip(functions_to_use.par_iter())
                            .for_each(|(correct_key_ct, function)| {
                                // Undo the `cast_rshift` offset before applying the function.
                                let acc = self
                                    .dest_server_key
                                    .generate_lookup_table(|n| function(n >> cast_rshift));
                                *correct_key_ct =
                                    self.dest_server_key.apply_lookup_table(&ciphertext, &acc);
                            });
                    }
                    CastCiphertext::WrongKeyRequiresPBS {
                        ct: wrong_key_ct,
                        degree: _,
                    } => {
                        output_cts
                            .par_iter_mut()
                            .zip(functions_to_use.par_iter())
                            .for_each(|(correct_key_ct, function)| {
                                ShortintEngine::with_thread_local_mut(|engine| {
                                    let buffers = engine.get_computation_buffers();
                                    let acc = self.dest_server_key.generate_lookup_table(|n| {
                                        function(n >> cast_rshift)
                                    });
                                    apply_programmable_bootstrap(
                                        compute_bsk,
                                        &wrong_key_ct,
                                        &mut correct_key_ct.ct,
                                        &acc.acc,
                                        buffers,
                                    );
                                    correct_key_ct.degree = acc.degree;
                                    correct_key_ct.set_noise_level_to_nominal();
                                });
                            });
                    }
                }
            }
            Ordering::Less => {
                match ct_to_cast {
                    CastCiphertext::CorrectKey(ciphertext) => {
                        output_cts
                            .par_iter_mut()
                            .zip(functions_to_use.par_iter())
                            .for_each(|(correct_key_ct, function)| {
                                let acc = self.dest_server_key.generate_lookup_table(function);
                                *correct_key_ct =
                                    self.dest_server_key.apply_lookup_table(&ciphertext, &acc);
                                if using_user_provided_functions {
                                    correct_key_ct.degree = acc.degree;
                                } else {
                                    // The pre-processing shifted the message left; the degree
                                    // is mapped back by the same amount.
                                    let new_degree =
                                        Degree::new(ciphertext.degree.get() >> -cast_rshift);
                                    correct_key_ct.degree = new_degree;
                                }
                            });
                    }
                    CastCiphertext::WrongKeyRequiresPBS {
                        ct: wrong_key_ct,
                        degree: degree_after_keyswitch,
                    } => {
                        output_cts
                            .par_iter_mut()
                            .zip(functions_to_use.par_iter())
                            .for_each(|(correct_key_ct, function)| {
                                ShortintEngine::with_thread_local_mut(|engine| {
                                    let buffers = engine.get_computation_buffers();
                                    let acc = self.dest_server_key.generate_lookup_table(function);
                                    apply_programmable_bootstrap(
                                        compute_bsk,
                                        &wrong_key_ct,
                                        &mut correct_key_ct.ct,
                                        &acc.acc,
                                        buffers,
                                    );
                                    if using_user_provided_functions {
                                        correct_key_ct.degree = acc.degree;
                                    } else {
                                        let new_degree = Degree::new(
                                            degree_after_keyswitch.get() >> -cast_rshift,
                                        );
                                        correct_key_ct.degree = new_degree;
                                    }
                                    correct_key_ct.set_noise_level_to_nominal();
                                });
                            });
                    }
                }
            }
        }
        output_cts
    }
}
/// Compressed (seeded) counterpart of [`KeySwitchingKeyMaterial`].
#[derive(Clone, Debug, Serialize, Deserialize, Versionize)]
#[versionize(CompressedKeySwitchingKeyMaterialVersions)]
pub struct CompressedKeySwitchingKeyMaterial {
    /// Seeded LWE keyswitching key; expanded by `decompress`.
    pub(crate) key_switching_key: SeededLweKeyswitchKeyOwned<u64>,
    /// Same meaning as [`KeySwitchingKeyMaterial`]'s `cast_rshift`.
    pub(crate) cast_rshift: i8,
    /// Destination key (big or small) under which the keyswitch output is encrypted.
    pub(crate) destination_key: EncryptionKeyChoice,
    /// Atomic pattern of the destination server key.
    pub(crate) destination_atomic_pattern: KeySwitchingKeyDestinationAtomicPattern,
}
impl CompressedKeySwitchingKeyMaterial {
    /// Expands the seeded keyswitching key (in parallel) and returns uncompressed material with
    /// the same cast metadata.
    pub fn decompress(&self) -> KeySwitchingKeyMaterial {
        let Self {
            key_switching_key,
            cast_rshift,
            destination_key,
            destination_atomic_pattern,
        } = self;
        KeySwitchingKeyMaterial {
            key_switching_key: key_switching_key
                .as_view()
                .par_decompress_into_lwe_keyswitch_key(),
            cast_rshift: *cast_rshift,
            destination_key: *destination_key,
            destination_atomic_pattern: *destination_atomic_pattern,
        }
    }

    /// Assembles compressed material from its components without any validation.
    pub fn from_raw_parts(
        key_switching_key: SeededLweKeyswitchKeyOwned<u64>,
        cast_rshift: i8,
        destination_key: EncryptionKeyChoice,
        destination_atomic_pattern: KeySwitchingKeyDestinationAtomicPattern,
    ) -> Self {
        Self {
            destination_atomic_pattern,
            destination_key,
            cast_rshift,
            key_switching_key,
        }
    }

    /// Consumes the material and returns
    /// `(key_switching_key, cast_rshift, destination_key, destination_atomic_pattern)`.
    pub fn into_raw_parts(
        self,
    ) -> (
        SeededLweKeyswitchKeyOwned<u64>,
        i8,
        EncryptionKeyChoice,
        KeySwitchingKeyDestinationAtomicPattern,
    ) {
        (
            self.key_switching_key,
            self.cast_rshift,
            self.destination_key,
            self.destination_atomic_pattern,
        )
    }
}
/// Intermediate produced while generating a [`CompressedKeySwitchingKey`].
///
/// Owns the freshly generated seeded material and borrows the compressed server keys.
pub(crate) struct CompressedKeySwitchingKeyBuildHelper<'keys> {
    /// Newly generated seeded keyswitching material.
    pub(crate) key_switching_key_material: CompressedKeySwitchingKeyMaterial,
    /// Destination compressed server key.
    pub(crate) dest_server_key: &'keys CompressedServerKey,
    /// Source compressed server key. Required when casting to a smaller full message modulus.
    pub(crate) src_server_key: Option<&'keys CompressedServerKey>,
}
/// Compressed counterpart of [`KeySwitchingKey`]; call `decompress` to obtain a usable key.
#[derive(Clone, Debug, Serialize, Deserialize, Versionize)]
#[versionize(CompressedKeySwitchingKeyVersions)]
pub struct CompressedKeySwitchingKey {
    /// Seeded keyswitching material and cast metadata.
    pub(crate) key_switching_key_material: CompressedKeySwitchingKeyMaterial,
    /// Compressed destination server key.
    pub(crate) dest_server_key: CompressedServerKey,
    /// Compressed source server key. Required when casting to a smaller full message modulus.
    pub(crate) src_server_key: Option<CompressedServerKey>,
}
impl From<CompressedKeySwitchingKeyBuildHelper<'_>> for CompressedKeySwitchingKey {
fn from(value: CompressedKeySwitchingKeyBuildHelper) -> Self {
let CompressedKeySwitchingKeyBuildHelper {
key_switching_key_material,
dest_server_key,
src_server_key,
} = value;
Self {
key_switching_key_material,
dest_server_key: dest_server_key.to_owned(),
src_server_key: src_server_key.map(ToOwned::to_owned),
}
}
}
impl<'keys> CompressedKeySwitchingKeyBuildHelper<'keys> {
    /// Generates seeded keyswitching material from the input secret key to the destination
    /// client key, using the thread-local [`ShortintEngine`].
    ///
    /// The parameter checks run before the seeded keyswitching key is generated, so invalid
    /// inputs fail without paying for key generation.
    ///
    /// # Panics
    ///
    /// - If either full message modulus (`carry_modulus * message_modulus`) is not a power of 2.
    /// - If the input full message modulus is larger than the output one and no source
    ///   [`CompressedServerKey`] is provided.
    pub(crate) fn new<'input_key, InputEncryptionKey>(
        input_key_pair: (InputEncryptionKey, Option<&'keys CompressedServerKey>),
        output_key_pair: (&'keys ClientKey, &'keys CompressedServerKey),
        params: ShortintKeySwitchingParameters,
    ) -> Self
    where
        InputEncryptionKey: Into<SecretEncryptionKeyView<'input_key>>,
    {
        let input_secret_key: SecretEncryptionKeyView<'_> = input_key_pair.0.into();
        let (output_cks, output_sks) = output_key_pair;

        let full_message_modulus_input =
            input_secret_key.carry_modulus.0 * input_secret_key.message_modulus.0;
        let full_message_modulus_output =
            output_cks.parameters().carry_modulus().0 * output_cks.parameters().message_modulus().0;

        // Validate the parameters before the (expensive) keyswitching key generation.
        assert!(
            full_message_modulus_input.is_power_of_two()
                && full_message_modulus_output.is_power_of_two(),
            "Cannot create casting key if the full messages moduli are not a power of 2"
        );
        if full_message_modulus_input > full_message_modulus_output {
            assert!(
                input_key_pair.1.is_some(),
                "Trying to build a shortint::KeySwitchingKey \
                going from a large modulus {full_message_modulus_input} \
                to a smaller modulus {full_message_modulus_output} \
                without providing a source ServerKey, this is not supported"
            );
        }

        let key_switching_key = ShortintEngine::with_thread_local_mut(|engine| {
            output_cks
                .atomic_pattern
                .new_seeded_keyswitching_key_with_engine(&input_secret_key, params, engine)
        });

        // Both moduli are powers of two, so ilog2 is their exact bit count (< 64, fits an i8).
        let nb_bits_input: i8 = full_message_modulus_input.ilog2().try_into().unwrap();
        let nb_bits_output: i8 = full_message_modulus_output.ilog2().try_into().unwrap();

        Self {
            key_switching_key_material: CompressedKeySwitchingKeyMaterial {
                key_switching_key,
                cast_rshift: nb_bits_output - nb_bits_input,
                destination_key: params.destination_key,
                destination_atomic_pattern: output_cks.atomic_pattern.kind().into(),
            },
            dest_server_key: output_sks,
            src_server_key: input_key_pair.1,
        }
    }
}
impl CompressedKeySwitchingKey {
    /// Generates a compressed key switching key from an input secret key to the parameters of
    /// `output_key_pair`, using the thread-local engine.
    ///
    /// # Panics
    ///
    /// If either full message modulus is not a power of two, or when casting to a smaller full
    /// message modulus without a source [`CompressedServerKey`].
    pub fn new<'input_key, InputEncryptionKey>(
        input_key_pair: (InputEncryptionKey, Option<&CompressedServerKey>),
        output_key_pair: (&ClientKey, &CompressedServerKey),
        params: ShortintKeySwitchingParameters,
    ) -> Self
    where
        InputEncryptionKey: Into<SecretEncryptionKeyView<'input_key>>,
    {
        let helper =
            CompressedKeySwitchingKeyBuildHelper::new(input_key_pair, output_key_pair, params);
        Self::from(helper)
    }

    /// Decompresses the material and both server keys into a usable [`KeySwitchingKey`].
    pub fn decompress(&self) -> KeySwitchingKey {
        let key_switching_key_material = self.key_switching_key_material.decompress();
        let dest_server_key = self.dest_server_key.decompress();
        let src_server_key = self
            .src_server_key
            .as_ref()
            .map(|compressed| compressed.decompress());
        KeySwitchingKey {
            key_switching_key_material,
            dest_server_key,
            src_server_key,
        }
    }

    /// Consumes the key and returns `(material, destination server key, source server key)`.
    pub fn into_raw_parts(
        self,
    ) -> (
        CompressedKeySwitchingKeyMaterial,
        CompressedServerKey,
        Option<CompressedServerKey>,
    ) {
        (
            self.key_switching_key_material,
            self.dest_server_key,
            self.src_server_key,
        )
    }

    /// Reassembles a compressed key from its parts after checking that they are consistent.
    ///
    /// Only destination keys using the standard atomic pattern are accepted.
    ///
    /// # Panics
    ///
    /// - If a source key is given and its ciphertext LWE dimension or ciphertext modulus does not
    ///   match the keyswitching key input dimension / the destination modulus.
    /// - If no source key is given while `cast_rshift` is negative.
    /// - If the destination key does not use the standard atomic pattern.
    /// - If the keyswitching key output dimension or modulus does not match the destination.
    pub fn from_raw_parts(
        key_switching_key_material: CompressedKeySwitchingKeyMaterial,
        dest_server_key: CompressedServerKey,
        src_server_key: Option<CompressedServerKey>,
    ) -> Self {
        let seeded_ksk = &key_switching_key_material.key_switching_key;

        if let Some(src) = src_server_key.as_ref() {
            let src_lwe_dimension = src.ciphertext_lwe_dimension();
            let ksk_input_dimension = seeded_ksk.input_key_lwe_dimension();
            assert_eq!(
                src_lwe_dimension,
                ksk_input_dimension,
                "Mismatch between the source CompressedServerKey ciphertext LweDimension ({:?}) \
                and the SeededLweKeyswitchKey input LweDimension ({:?})",
                src_lwe_dimension,
                ksk_input_dimension,
            );

            let src_modulus = src.ciphertext_modulus();
            let dest_modulus = dest_server_key.ciphertext_modulus();
            assert_eq!(
                src_modulus,
                dest_modulus,
                "Mismatch between the source CompressedServerKey CiphertextModulus ({:?}) \
                and the destination CompressedServerKey CiphertextModulus ({:?})",
                src_modulus,
                dest_modulus,
            );
        } else {
            assert!(
                key_switching_key_material.cast_rshift >= 0,
                "Trying to build a shortint::CompressedKeySwitchingKey with a negative cast_rshift \
                without providing a source CompressedServerKey, this is not supported"
            );
        }

        let std_dest_server_key = dest_server_key
            .as_compressed_standard_atomic_pattern_server_key()
            .expect(
                "Trying to build a shortint::CompressedKeySwitchingKey \
                with an unsupported atomic pattern",
            );
        let dest_bootstrapping_key = std_dest_server_key.bootstrapping_key();

        // The small key is the PBS input key, the big key is the PBS output key.
        let dst_lwe_dimension = match key_switching_key_material.destination_key {
            EncryptionKeyChoice::Small => dest_bootstrapping_key.input_lwe_dimension(),
            EncryptionKeyChoice::Big => dest_bootstrapping_key.output_lwe_dimension(),
        };
        let ksk_output_dimension = seeded_ksk.output_key_lwe_dimension();
        assert_eq!(
            dst_lwe_dimension,
            ksk_output_dimension,
            "Mismatch between the destination CompressedServerKey ciphertext LweDimension ({:?}) \
            and the SeededLweKeyswitchKey output LweDimension ({:?})",
            dst_lwe_dimension,
            ksk_output_dimension,
        );

        let ksk_modulus = seeded_ksk.ciphertext_modulus();
        let dest_modulus = dest_server_key.ciphertext_modulus();
        assert_eq!(
            ksk_modulus,
            dest_modulus,
            "Mismatch between the SeededLweKeyswitchKey CiphertextModulus ({:?}) \
            and the destination CompressedServerKey CiphertextModulus ({:?})",
            ksk_modulus,
            dest_modulus,
        );

        Self {
            key_switching_key_material,
            dest_server_key,
            src_server_key,
        }
    }
}
/// Expected parameters used to check (compressed) keyswitching material for conformance.
pub struct KeySwitchingKeyConformanceParams {
    /// Expected parameters of the underlying LWE keyswitching key.
    pub keyswitch_key_conformance_params: LweKeyswitchKeyConformanceParams<u64>,
    /// Expected `cast_rshift` value (compared for exact equality).
    pub cast_rshift: i8,
    /// Expected destination key choice.
    pub destination_key: EncryptionKeyChoice,
    /// Expected destination atomic pattern.
    pub destination_atomic_pattern: KeySwitchingKeyDestinationAtomicPattern,
}
impl ParameterSetConformant for KeySwitchingKeyMaterial {
    type ParameterSet = KeySwitchingKeyConformanceParams;

    /// Returns `true` when the keyswitching key conforms to the expected LWE keyswitch key
    /// parameters and every cast metadata field equals the expected one.
    fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool {
        // Exhaustive destructuring: adding a field forces this check to be updated.
        let Self {
            key_switching_key,
            cast_rshift,
            destination_key,
            destination_atomic_pattern,
        } = self;
        key_switching_key.is_conformant(&parameter_set.keyswitch_key_conformance_params)
            && *cast_rshift == parameter_set.cast_rshift
            && *destination_key == parameter_set.destination_key
            && *destination_atomic_pattern == parameter_set.destination_atomic_pattern
    }
}
impl ParameterSetConformant for CompressedKeySwitchingKeyMaterial {
    type ParameterSet = KeySwitchingKeyConformanceParams;

    /// Returns `true` when the seeded keyswitching key conforms to the expected LWE keyswitch
    /// key parameters and every cast metadata field equals the expected one.
    fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool {
        // Exhaustive destructuring: adding a field forces this check to be updated.
        let Self {
            key_switching_key,
            cast_rshift,
            destination_key,
            destination_atomic_pattern,
        } = self;
        key_switching_key.is_conformant(&parameter_set.keyswitch_key_conformance_params)
            && *cast_rshift == parameter_set.cast_rshift
            && *destination_key == parameter_set.destination_key
            && *destination_atomic_pattern == parameter_set.destination_atomic_pattern
    }
}