use crate::core_crypto::gpu::lwe_bootstrap_key::{
CudaLweBootstrapKey, CudaModulusSwitchNoiseReductionConfiguration,
};
use crate::core_crypto::gpu::lwe_keyswitch_key::CudaLweKeyswitchKey;
use crate::core_crypto::gpu::lwe_multi_bit_bootstrap_key::CudaLweMultiBitBootstrapKey;
use crate::core_crypto::gpu::CudaStreams;
use crate::core_crypto::prelude::{
allocate_and_generate_new_lwe_keyswitch_key, par_allocate_and_generate_new_lwe_bootstrap_key,
par_allocate_and_generate_new_lwe_multi_bit_bootstrap_key, GlweSize, LweBootstrapKeyOwned,
LweDimension, LweMultiBitBootstrapKeyOwned, UnsignedInteger,
};
use crate::high_level_api::keys::expanded::{
ShortintExpandedBootstrappingKey, ShortintExpandedServerKey,
};
use crate::integer::server_key::num_bits_to_represent_unsigned_value;
use crate::integer::ClientKey;
use crate::shortint::atomic_pattern::expanded::{
ExpandedAtomicPatternServerKey, ExpandedKS32AtomicPatternServerKey,
ExpandedStandardAtomicPatternServerKey,
};
use crate::shortint::ciphertext::{MaxDegree, MaxNoiseLevel};
use crate::shortint::client_key::atomic_pattern::AtomicPatternClientKey;
use crate::shortint::engine::ShortintEngine;
use crate::shortint::oprf::ExpandedOprfBootstrappingKey;
use crate::shortint::parameters::ModulusSwitchType;
use crate::shortint::prelude::PolynomialSize;
use crate::shortint::{CarryModulus, CiphertextModulus, MessageModulus, PBSOrder};
pub use radix::{CudaOprfServerKey, CudaOprfServerKeyView, GenericCudaOprfServerKey};
mod radix;
/// A bootstrapping key held as a CUDA key object.
///
/// `Classic` wraps a [`CudaLweBootstrapKey`]; `MultiBit` wraps a
/// [`CudaLweMultiBitBootstrapKey`], which is the only variant that uses the `Scalar`
/// type parameter.
pub enum CudaBootstrappingKey<Scalar>
where
    Scalar: UnsignedInteger,
{
    /// Classic (non multi-bit) bootstrapping key.
    Classic(CudaLweBootstrapKey),
    /// Multi-bit bootstrapping key.
    MultiBit(CudaLweMultiBitBootstrapKey<Scalar>),
}
impl<Scalar> CudaBootstrappingKey<Scalar>
where
    Scalar: UnsignedInteger,
{
    /// Converts an expanded shortint bootstrapping key into its CUDA counterpart.
    ///
    /// For a classic key, the expanded modulus switch noise reduction key is turned into a
    /// [`CudaModulusSwitchNoiseReductionConfiguration`] and passed along with the key.
    /// For a multi-bit key, the host-side `thread_count` and `deterministic_execution`
    /// settings are not used.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by
    /// [`CudaModulusSwitchNoiseReductionConfiguration::from_modulus_switch_configuration`]
    /// for classic keys.
    pub(crate) fn from_expanded_bootstrapping_key<ModSwitchScalar>(
        expanded_bsk: &ShortintExpandedBootstrappingKey<Scalar, ModSwitchScalar>,
        streams: &CudaStreams,
    ) -> crate::Result<Self>
    where
        ModSwitchScalar: UnsignedInteger,
    {
        let cuda_key = match expanded_bsk {
            ShortintExpandedBootstrappingKey::Classic {
                bsk,
                modulus_switch_noise_reduction_key: ms_key,
            } => {
                let ms_config =
                    CudaModulusSwitchNoiseReductionConfiguration::from_modulus_switch_configuration(
                        ms_key,
                    )?;
                Self::Classic(CudaLweBootstrapKey::from_lwe_bootstrap_key(
                    bsk, ms_config, streams,
                ))
            }
            // thread_count / deterministic_execution only matter for the CPU implementation.
            ShortintExpandedBootstrappingKey::MultiBit { bsk, .. } => Self::MultiBit(
                CudaLweMultiBitBootstrapKey::from_lwe_multi_bit_bootstrap_key(bsk, streams),
            ),
        };
        Ok(cuda_key)
    }

    /// Polynomial size of the wrapped key.
    pub(crate) fn polynomial_size(&self) -> PolynomialSize {
        match self {
            Self::Classic(key) => key.polynomial_size,
            Self::MultiBit(key) => key.polynomial_size,
        }
    }

    /// LWE dimension of the ciphertexts the wrapped key takes as input.
    pub(crate) fn input_lwe_dimension(&self) -> LweDimension {
        match self {
            Self::Classic(key) => key.input_lwe_dimension,
            Self::MultiBit(key) => key.input_lwe_dimension,
        }
    }

    /// LWE dimension of the ciphertexts produced with the wrapped key.
    pub(crate) fn output_lwe_dimension(&self) -> LweDimension {
        match self {
            Self::Classic(key) => key.output_lwe_dimension(),
            Self::MultiBit(key) => key.output_lwe_dimension(),
        }
    }

    /// GLWE size (GLWE dimension converted with `to_glwe_size`) of the wrapped key.
    pub(crate) fn glwe_size(&self) -> GlweSize {
        match self {
            Self::Classic(key) => key.glwe_dimension().to_glwe_size(),
            Self::MultiBit(key) => key.glwe_dimension().to_glwe_size(),
        }
    }
}
impl CudaBootstrappingKey<u64> {
    /// Converts an expanded OPRF bootstrapping key into its CUDA counterpart.
    ///
    /// Classic keys are converted with no modulus switch noise reduction configuration
    /// (`None`); any other fields of the classic variant are ignored. For multi-bit keys the
    /// host-side `thread_count` and `deterministic_execution` settings are not used.
    pub(crate) fn from_expanded_oprf_server_key(
        expanded_bsk: &ExpandedOprfBootstrappingKey,
        streams: &CudaStreams,
    ) -> Self {
        match expanded_bsk {
            ExpandedOprfBootstrappingKey::Classic { bsk, .. } => {
                Self::Classic(CudaLweBootstrapKey::from_lwe_bootstrap_key(bsk, None, streams))
            }
            ExpandedOprfBootstrappingKey::MultiBit { bsk, .. } => Self::MultiBit(
                CudaLweMultiBitBootstrapKey::from_lwe_multi_bit_bootstrap_key(bsk, streams),
            ),
        }
    }
}
/// A CUDA keyswitching key whose scalar width depends on the atomic pattern it came from.
pub enum CudaDynamicKeyswitchingKey {
    /// Keyswitching key over `u64` scalars (standard atomic pattern).
    Standard(CudaLweKeyswitchKey<u64>),
    /// Keyswitching key over `u32` scalars (KS32 atomic pattern).
    KeySwitch32(CudaLweKeyswitchKey<u32>),
}
/// Server key whose keyswitching and bootstrapping keys are CUDA key objects.
pub struct CudaServerKey {
    /// Keyswitching key; the variant records whether it works on `u64` or `u32` scalars.
    pub key_switching_key: CudaDynamicKeyswitchingKey,
    /// Bootstrapping key (classic or multi-bit).
    pub bootstrapping_key: CudaBootstrappingKey<u64>,
    /// Message modulus of the parameters the key was built from.
    pub message_modulus: MessageModulus,
    /// Carry modulus of the parameters the key was built from.
    pub carry_modulus: CarryModulus,
    /// Maximum degree allowed for ciphertext blocks handled with this key.
    pub max_degree: MaxDegree,
    /// Maximum noise level allowed for ciphertext blocks handled with this key.
    pub max_noise_level: MaxNoiseLevel,
    /// Ciphertext modulus of the parameters the key was built from.
    pub ciphertext_modulus: CiphertextModulus,
    /// Order in which keyswitch and PBS are applied.
    pub pbs_order: PBSOrder,
}
impl CudaServerKey {
pub fn new<C>(cks: C, streams: &CudaStreams) -> Self
where
C: AsRef<ClientKey>,
{
let client_key = cks.as_ref();
let max_degree = MaxDegree::integer_radix_server_key(
client_key.key.parameters().message_modulus(),
client_key.key.parameters().carry_modulus(),
);
Self::new_server_key_with_max_degree(client_key, max_degree, streams)
}
pub(crate) fn new_server_key_with_max_degree(
cks: &ClientKey,
max_degree: MaxDegree,
streams: &CudaStreams,
) -> Self {
let mut engine = ShortintEngine::new();
let AtomicPatternClientKey::Standard(std_cks) = &cks.key.atomic_pattern else {
panic!("Only the standard atomic pattern is supported on GPU")
};
let pbs_params_base = std_cks.parameters;
let d_bootstrapping_key = match pbs_params_base {
crate::shortint::PBSParameters::PBS(pbs_params) => {
let h_bootstrap_key: LweBootstrapKeyOwned<u64> =
par_allocate_and_generate_new_lwe_bootstrap_key(
&std_cks.lwe_secret_key,
&std_cks.glwe_secret_key,
pbs_params.pbs_base_log,
pbs_params.pbs_level,
pbs_params.glwe_noise_distribution,
pbs_params.ciphertext_modulus,
&mut engine.encryption_generator,
);
let modulus_switch_noise_reduction_configuration =
match pbs_params.modulus_switch_noise_reduction_params {
ModulusSwitchType::Standard => None,
ModulusSwitchType::DriftTechniqueNoiseReduction(
_modulus_switch_noise_reduction_params,
) => {
panic!("Drift noise reduction is not supported on GPU")
}
ModulusSwitchType::CenteredMeanNoiseReduction => {
Some(CudaModulusSwitchNoiseReductionConfiguration::Centered)
}
};
let d_bootstrap_key = CudaLweBootstrapKey::from_lwe_bootstrap_key(
&h_bootstrap_key,
modulus_switch_noise_reduction_configuration,
streams,
);
CudaBootstrappingKey::Classic(d_bootstrap_key)
}
crate::shortint::PBSParameters::MultiBitPBS(pbs_params) => {
let h_bootstrap_key: LweMultiBitBootstrapKeyOwned<u64> =
par_allocate_and_generate_new_lwe_multi_bit_bootstrap_key(
&std_cks.lwe_secret_key,
&std_cks.glwe_secret_key,
pbs_params.pbs_base_log,
pbs_params.pbs_level,
pbs_params.grouping_factor,
pbs_params.glwe_noise_distribution,
pbs_params.ciphertext_modulus,
&mut engine.encryption_generator,
);
let d_bootstrap_key = CudaLweMultiBitBootstrapKey::from_lwe_multi_bit_bootstrap_key(
&h_bootstrap_key,
streams,
);
CudaBootstrappingKey::MultiBit(d_bootstrap_key)
}
};
let h_key_switching_key = allocate_and_generate_new_lwe_keyswitch_key(
&std_cks.large_lwe_secret_key(),
&std_cks.small_lwe_secret_key(),
std_cks.parameters.ks_base_log(),
std_cks.parameters.ks_level(),
std_cks.parameters.lwe_noise_distribution(),
std_cks.parameters.ciphertext_modulus(),
&mut engine.encryption_generator,
);
let d_key_switching_key =
CudaLweKeyswitchKey::from_lwe_keyswitch_key(&h_key_switching_key, streams);
assert!(matches!(
std_cks.parameters.encryption_key_choice().into(),
PBSOrder::KeyswitchBootstrap
));
Self {
key_switching_key: CudaDynamicKeyswitchingKey::Standard(d_key_switching_key),
bootstrapping_key: d_bootstrapping_key,
message_modulus: std_cks.parameters.message_modulus(),
carry_modulus: std_cks.parameters.carry_modulus(),
max_degree,
max_noise_level: std_cks.parameters.max_noise_level(),
ciphertext_modulus: std_cks.parameters.ciphertext_modulus(),
pbs_order: std_cks.parameters.encryption_key_choice().into(),
}
}
pub fn decompress_from_cpu(
cpu_key: &crate::integer::CompressedServerKey,
streams: &CudaStreams,
) -> Self {
let crate::shortint::CompressedServerKey {
compressed_ap_server_key,
message_modulus,
carry_modulus,
max_degree,
max_noise_level,
} = &cpu_key.key;
let expanded_ap = compressed_ap_server_key.expand();
let expanded = ShortintExpandedServerKey {
atomic_pattern: expanded_ap,
message_modulus: *message_modulus,
carry_modulus: *carry_modulus,
max_degree: *max_degree,
max_noise_level: *max_noise_level,
ciphertext_modulus: compressed_ap_server_key.ciphertext_modulus(),
};
Self::from_expanded_server_key(&expanded, streams).expect("Unsupported configuration")
}
pub(crate) fn from_expanded_server_key(
expanded: &ShortintExpandedServerKey,
streams: &CudaStreams,
) -> crate::Result<Self> {
let message_modulus = expanded.message_modulus;
let carry_modulus = expanded.carry_modulus;
let max_degree = expanded.max_degree;
let max_noise_level = expanded.max_noise_level;
match &expanded.atomic_pattern {
ExpandedAtomicPatternServerKey::Standard(std_key) => {
let ExpandedStandardAtomicPatternServerKey {
key_switching_key,
bootstrapping_key,
pbs_order,
} = std_key;
let ciphertext_modulus = key_switching_key.ciphertext_modulus();
let cuda_key_switching_key =
CudaLweKeyswitchKey::from_lwe_keyswitch_key(key_switching_key, streams);
let cuda_bootstrapping_key = CudaBootstrappingKey::from_expanded_bootstrapping_key(
bootstrapping_key,
streams,
)?;
Ok(Self {
key_switching_key: CudaDynamicKeyswitchingKey::Standard(cuda_key_switching_key),
bootstrapping_key: cuda_bootstrapping_key,
message_modulus,
carry_modulus,
max_degree,
max_noise_level,
ciphertext_modulus,
pbs_order: *pbs_order,
})
}
ExpandedAtomicPatternServerKey::KeySwitch32(ks32_key) => {
let ExpandedKS32AtomicPatternServerKey {
key_switching_key,
bootstrapping_key,
ciphertext_modulus,
} = ks32_key;
let cuda_key_switching_key =
CudaLweKeyswitchKey::from_lwe_keyswitch_key(key_switching_key, streams);
let cuda_bootstrapping_key = CudaBootstrappingKey::from_expanded_bootstrapping_key(
bootstrapping_key,
streams,
)?;
Ok(Self {
key_switching_key: CudaDynamicKeyswitchingKey::KeySwitch32(
cuda_key_switching_key,
),
bootstrapping_key: cuda_bootstrapping_key,
message_modulus,
carry_modulus,
max_degree,
max_noise_level,
ciphertext_modulus: *ciphertext_modulus,
pbs_order: PBSOrder::KeyswitchBootstrap,
})
}
}
}
pub(crate) fn num_blocks_to_represent_unsigned_value<Clear>(&self, clear: Clear) -> usize
where
Clear: UnsignedInteger,
{
let num_bits_to_represent_output_value = num_bits_to_represent_unsigned_value(clear);
let num_bits_in_message = self.message_modulus.0.ilog2();
num_bits_to_represent_output_value.div_ceil(num_bits_in_message as usize)
}
}