pub mod comparator;
pub(crate) mod crt;
mod crt_parallel;
pub(crate) mod radix;
pub(crate) mod radix_parallel;
pub use radix_parallel::kv_store::{CompressedKVStore, KVStore};
use super::backward_compatibility::server_key::{CompressedServerKeyVersions, ServerKeyVersions};
use crate::conformance::ParameterSetConformant;
use crate::core_crypto::prelude::UnsignedInteger;
use crate::integer::client_key::ClientKey;
use crate::shortint::atomic_pattern::AtomicPatternParameters;
use crate::shortint::ciphertext::{Degree, MaxDegree};
pub use crate::shortint::CheckError;
use crate::shortint::{CarryModulus, MessageModulus};
pub use radix::scalar_mul::ScalarMultiplier;
pub use radix::scalar_sub::TwosComplementNegation;
pub use radix_parallel::{MatchValues, MiniUnsignedInteger, Reciprocable};
use serde::{Deserialize, Serialize};
use tfhe_versionable::Versionize;
/// Server key used to evaluate homomorphic operations on integer ciphertexts.
///
/// This is a thin wrapper around a shortint server key. The integer-level max degree
/// (radix or CRT flavour) is stored in the wrapped shortint key itself, so the choice of
/// constructor determines which decomposition the key is configured for.
#[derive(Serialize, Deserialize, Clone, Versionize)]
#[versionize(ServerKeyVersions)]
pub struct ServerKey {
    // Underlying shortint key; its `max_degree` is overwritten by the integer constructors.
    pub(crate) key: crate::shortint::ServerKey,
}
impl From<ServerKey> for crate::shortint::ServerKey {
fn from(key: ServerKey) -> Self {
key.key
}
}
impl MaxDegree {
    /// Max degree used by radix integer server keys.
    ///
    /// Equals `message_modulus * carry_modulus - 1` (the full message/carry degree) minus
    /// `carry_modulus - 1`. The reserved headroom is presumably there so a block can still
    /// absorb an incoming carry during propagation — TODO confirm.
    pub(crate) fn integer_radix_server_key(
        message_modulus: MessageModulus,
        carry_modulus: CarryModulus,
    ) -> Self {
        let total_space = message_modulus.0 * carry_modulus.0;
        let reserved_for_carry = carry_modulus.0 - 1;
        Self::new((total_space - 1) - reserved_for_carry)
    }
}
impl MaxDegree {
    /// Max degree used by CRT integer server keys: the whole message/carry space,
    /// `message_modulus * carry_modulus - 1`, with no headroom reserved.
    fn integer_crt_server_key(
        message_modulus: MessageModulus,
        carry_modulus: CarryModulus,
    ) -> Self {
        let space_size = message_modulus.0 * carry_modulus.0;
        Self::new(space_size - 1)
    }
}
impl ServerKey {
    /// Generates a server key for radix-decomposed integers from a client key.
    ///
    /// The underlying shortint key is created with the max degree returned by
    /// `MaxDegree::integer_radix_server_key`, i.e. `carry_modulus - 1` below the full
    /// message/carry degree.
    pub fn new_radix_server_key<C>(cks: C) -> Self
    where
        C: AsRef<ClientKey>,
    {
        let client_key = cks.as_ref();
        let max_degree = MaxDegree::integer_radix_server_key(
            client_key.key.parameters().message_modulus(),
            client_key.key.parameters().carry_modulus(),
        );
        let sks = crate::shortint::server_key::ServerKey::new_with_max_degree(
            &client_key.key,
            max_degree,
        );
        Self { key: sks }
    }

    /// Generates a server key for CRT-decomposed integers from a client key.
    ///
    /// The underlying shortint key is created with the max degree returned by
    /// `MaxDegree::integer_crt_server_key`, i.e. the full message/carry degree.
    pub fn new_crt_server_key<C>(cks: C) -> Self
    where
        C: AsRef<ClientKey>,
    {
        let client_key = cks.as_ref();
        let max_degree = MaxDegree::integer_crt_server_key(
            client_key.key.parameters().message_modulus(),
            client_key.key.parameters().carry_modulus(),
        );
        let sks = crate::shortint::server_key::ServerKey::new_with_max_degree(
            &client_key.key,
            max_degree,
        );
        Self { key: sks }
    }

    /// Wraps an existing shortint server key for radix use.
    ///
    /// The key's `max_degree` is overwritten with the radix max degree computed from the
    /// key's own message and carry moduli.
    pub fn new_radix_server_key_from_shortint(
        mut key: crate::shortint::server_key::ServerKey,
    ) -> Self {
        key.max_degree =
            MaxDegree::integer_radix_server_key(key.message_modulus, key.carry_modulus);
        Self { key }
    }

    /// Wraps an existing shortint server key for CRT use.
    ///
    /// The key's `max_degree` is overwritten with the CRT max degree computed from the
    /// key's own message and carry moduli.
    pub fn new_crt_server_key_from_shortint(
        mut key: crate::shortint::server_key::ServerKey,
    ) -> Self {
        key.max_degree = MaxDegree::integer_crt_server_key(key.message_modulus, key.carry_modulus);
        Self { key }
    }

    /// Consumes the integer key and returns the wrapped shortint server key.
    pub fn into_raw_parts(self) -> crate::shortint::ServerKey {
        self.key
    }

    /// Wraps a shortint server key as-is.
    ///
    /// Unlike the `new_*_from_shortint` constructors, the key's max degree is left untouched.
    pub fn from_raw_parts(key: crate::shortint::ServerKey) -> Self {
        Self { key }
    }

    /// Returns whether the underlying shortint key is set to deterministic execution.
    pub fn deterministic_pbs_execution(&self) -> bool {
        self.key.deterministic_execution()
    }

    /// Sets the deterministic-execution flag of the underlying shortint key.
    pub fn set_deterministic_pbs_execution(&mut self, new_deterministic_execution: bool) {
        self.key
            .set_deterministic_execution(new_deterministic_execution);
    }

    /// Message modulus of the underlying shortint key.
    pub fn message_modulus(&self) -> MessageModulus {
        self.key.message_modulus
    }

    /// Carry modulus of the underlying shortint key.
    pub fn carry_modulus(&self) -> CarryModulus {
        self.key.carry_modulus
    }

    /// Returns the number of blocks needed to represent `clear`.
    ///
    /// Each block holds `ilog2(message_modulus)` message bits (this floors, so the message
    /// modulus is presumably a power of two — confirm for exotic parameters).
    ///
    /// # Panics
    ///
    /// Panics if the message modulus is smaller than 2, since a block would then hold no
    /// message bit.
    pub fn num_blocks_to_represent_unsigned_value<Clear>(&self, clear: Clear) -> usize
    where
        Clear: UnsignedInteger,
    {
        let num_bits_to_represent_output_value = num_bits_to_represent_unsigned_value(clear);
        let num_bits_in_message = self.message_modulus().0.ilog2();
        num_bits_to_represent_output_value.div_ceil(num_bits_in_message as usize)
    }

    /// Returns how many ciphertexts whose blocks have degree `degree` can be summed
    /// without exceeding the full message/carry max degree, capped by the key's max
    /// noise level.
    pub(crate) fn max_sum_size(&self, degree: Degree) -> usize {
        let max_degree =
            MaxDegree::from_msg_carry_modulus(self.message_modulus(), self.carry_modulus());
        let noise_bound = self.key.max_noise_level.get();
        // A degree of 0 means the blocks are known to be zero: summing them never increases
        // the degree, so only the noise bound applies (and we must not divide by zero).
        let max_sum_to_full_carry = max_degree
            .get()
            .checked_div(degree.get())
            .unwrap_or(noise_bound);
        max_sum_to_full_carry.min(noise_bound) as usize
    }
}
impl AsRef<crate::shortint::ServerKey> for ServerKey {
    /// Borrows the shortint server key wrapped by this integer key.
    fn as_ref(&self) -> &crate::shortint::ServerKey {
        let Self { key: shortint_key } = self;
        shortint_key
    }
}
/// Compressed form of the integer [`ServerKey`].
///
/// Wraps a shortint compressed server key; call `decompress` to obtain a usable
/// [`ServerKey`].
#[derive(Clone, Serialize, Deserialize, Versionize)]
#[versionize(CompressedServerKeyVersions)]
pub struct CompressedServerKey {
    // Underlying compressed shortint key, carrying the max degree it was generated with.
    pub(crate) key: crate::shortint::CompressedServerKey,
}
impl CompressedServerKey {
    /// Generates a compressed server key for radix-decomposed integers.
    ///
    /// Uses the same max degree computation as `ServerKey::new_radix_server_key`
    /// (`MaxDegree::integer_radix_server_key`).
    pub fn new_radix_compressed_server_key(client_key: &ClientKey) -> Self {
        let max_degree = MaxDegree::integer_radix_server_key(
            client_key.key.parameters().message_modulus(),
            client_key.key.parameters().carry_modulus(),
        );
        let key =
            crate::shortint::CompressedServerKey::new_with_max_degree(&client_key.key, max_degree);
        Self { key }
    }

    /// Generates a compressed server key for CRT-decomposed integers.
    ///
    /// The max degree is set explicitly with `MaxDegree::integer_crt_server_key`, matching
    /// `ServerKey::new_crt_server_key`, instead of relying on whatever default the shortint
    /// constructor picks.
    pub fn new_crt_compressed_server_key(client_key: &ClientKey) -> Self {
        let max_degree = MaxDegree::integer_crt_server_key(
            client_key.key.parameters().message_modulus(),
            client_key.key.parameters().carry_modulus(),
        );
        let key =
            crate::shortint::CompressedServerKey::new_with_max_degree(&client_key.key, max_degree);
        Self { key }
    }

    /// Decompresses the wrapped shortint key and returns it as an integer [`ServerKey`].
    pub fn decompress(&self) -> ServerKey {
        ServerKey {
            key: self.key.decompress(),
        }
    }

    /// Consumes the integer key and returns the wrapped compressed shortint key.
    pub fn into_raw_parts(self) -> crate::shortint::CompressedServerKey {
        self.key
    }

    /// Wraps a compressed shortint server key as-is (its max degree is left untouched).
    pub fn from_raw_parts(key: crate::shortint::CompressedServerKey) -> Self {
        Self { key }
    }
}
/// Returns the number of bits needed to write `clear` in binary.
///
/// Zero is considered to need one bit.
pub fn num_bits_to_represent_unsigned_value<Clear>(clear: Clear) -> usize
where
    Clear: UnsignedInteger,
{
    // `clear + 1` would overflow for the maximum value, which needs every bit anyway.
    if clear == Clear::MAX {
        return Clear::BITS;
    }
    // ceil(log2(clear + 1)) is the bit length of `clear`; it is 0 when `clear == 0`,
    // but zero still takes one bit to represent.
    let bit_length = (clear + Clear::ONE).ceil_ilog2() as usize;
    bit_length.max(1)
}
impl ParameterSetConformant for ServerKey {
    type ParameterSet = AtomicPatternParameters;

    /// Checks the wrapped shortint key against `parameter_set`, expecting the radix max
    /// degree. A key built with a different max degree (e.g. the CRT one) is presumably
    /// reported as non-conformant by the shortint check.
    fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool {
        // Exhaustive destructuring: adding a field to `ServerKey` forces revisiting this check.
        let Self { key: shortint_key } = self;

        let radix_max_degree = MaxDegree::integer_radix_server_key(
            parameter_set.message_modulus(),
            parameter_set.carry_modulus(),
        );
        let expected = (*parameter_set, radix_max_degree);
        shortint_key.is_conformant(&expected)
    }
}
impl ParameterSetConformant for CompressedServerKey {
    type ParameterSet = AtomicPatternParameters;

    /// Checks the wrapped compressed shortint key against `parameter_set`, expecting the
    /// radix max degree.
    fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool {
        // Exhaustive destructuring: adding a field to `CompressedServerKey` forces revisiting
        // this check.
        let Self {
            key: compressed_shortint_key,
        } = self;

        let radix_max_degree = MaxDegree::integer_radix_server_key(
            parameter_set.message_modulus(),
            parameter_set.carry_modulus(),
        );
        let expected = (*parameter_set, radix_max_degree);
        compressed_shortint_key.is_conformant(&expected)
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::integer::RadixClientKey;
    use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;

    /// Asserts that a server key, a compressed server key and the decompressed form of
    /// the latter all carry `expected` as their max degree.
    fn assert_max_degree_everywhere(
        sks: &ServerKey,
        csks: &CompressedServerKey,
        expected: &MaxDegree,
    ) {
        assert_eq!(&sks.key.max_degree, expected);
        assert_eq!(&csks.key.max_degree, expected);
        let decompressed_sks: ServerKey = csks.decompress();
        assert_eq!(&decompressed_sks.key.max_degree, expected);
    }

    #[test]
    fn test_compressed_server_key_max_degree() {
        // Radix keys: 4 * 4 - 1 minus (4 - 1) of carry headroom = 12.
        let radix_cks = ClientKey::new(
            crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128,
        );
        assert_max_degree_everywhere(
            &ServerKey::new_radix_server_key(&radix_cks),
            &CompressedServerKey::new_radix_compressed_server_key(&radix_cks),
            &MaxDegree::new(12),
        );

        // CRT keys use the full message/carry degree: 4 * 4 - 1 = 15.
        let crt_cks = ClientKey::new(
            crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128,
        );
        assert_max_degree_everywhere(
            &ServerKey::new_crt_server_key(&crt_cks),
            &CompressedServerKey::new_crt_compressed_server_key(&crt_cks),
            &MaxDegree::new(15),
        );

        // A decompressed radix key must still evaluate correctly: adding `ct` five times on
        // top of itself gives 6 * (modulus - 1), which wraps to modulus - 6.
        let client_key = RadixClientKey::new(PARAM_MESSAGE_2_CARRY_2, 14);
        let evaluation_key =
            CompressedServerKey::new_radix_compressed_server_key(client_key.as_ref())
                .decompress();
        let block_count = client_key.num_blocks() as u32;
        let modulus =
            (client_key.parameters().message_modulus().0 as u128).pow(block_count);

        let mut ct = client_key.encrypt(modulus - 1);
        let mut res_ct = ct.clone();
        for _ in 0..5 {
            res_ct = evaluation_key.smart_add_parallelized(&mut res_ct, &mut ct);
        }
        let res: u128 = client_key.decrypt(&res_ct);
        assert_eq!(modulus - 6, res);
    }
}