use std::marker::PhantomData;
use tfhe_versionable::Versionize;
use crate::backward_compatibility::integers::{
CompressedFheIntVersions, CompressedSignedRadixCiphertextVersions,
};
use crate::conformance::ParameterSetConformant;
use crate::core_crypto::prelude::SignedNumeric;
use crate::high_level_api::global_state;
use crate::high_level_api::integers::{FheInt, FheIntId};
use crate::high_level_api::keys::InternalServerKey;
use crate::high_level_api::re_randomization::ReRandomizationMetadata;
use crate::high_level_api::traits::Tagged;
use crate::integer::block_decomposition::DecomposableInto;
use crate::integer::ciphertext::{
CompressedModulusSwitchedRadixCiphertextConformanceParams,
CompressedModulusSwitchedSignedRadixCiphertext,
CompressedSignedRadixCiphertext as IntegerCompressedSignedRadixCiphertext,
};
use crate::named::Named;
use crate::prelude::FheTryEncrypt;
use crate::shortint::AtomicPatternParameters;
use crate::{ClientKey, ServerKey, Tag};
/// A signed FHE integer kept in compressed form, to be serialized or stored compactly.
///
/// Turn it back into an operable [`FheInt`] with `decompress`.
///
/// NOTE(review): the serde/Versionize derives serialize fields in declaration order.
/// Reordering the fields below likely changes the on-wire format, so do not reorder
/// them without adding a new version.
#[derive(Clone, serde::Serialize, serde::Deserialize, Versionize)]
#[versionize(CompressedFheIntVersions)]
pub struct CompressedFheInt<Id>
where
    Id: FheIntId,
{
    /// The compressed payload: either a seeded encryption or a modulus-switched ciphertext.
    pub(in crate::high_level_api) ciphertext: CompressedSignedRadixCiphertext,
    /// Type marker for the integer kind; `new` fills it with `Id::default()`.
    pub(in crate::high_level_api) id: Id,
    /// Tag carried along and cloned into the `FheInt` produced by decompression.
    pub(crate) tag: Tag,
}
impl<Id> Tagged for CompressedFheInt<Id>
where
    Id: FheIntId,
{
    /// Borrows the tag attached to this compressed ciphertext.
    fn tag(&self) -> &Tag {
        let Self { tag, .. } = self;
        tag
    }

    /// Mutably borrows the tag so that callers can overwrite it in place.
    fn tag_mut(&mut self) -> &mut Tag {
        let Self { tag, .. } = self;
        tag
    }
}
impl<Id> CompressedFheInt<Id>
where
    Id: FheIntId,
{
    /// Wraps an already-compressed signed radix ciphertext together with `tag`.
    ///
    /// The type marker is filled with `Id::default()`.
    pub(in crate::high_level_api::integers) fn new(
        inner: CompressedSignedRadixCiphertext,
        tag: Tag,
    ) -> Self {
        Self::from_raw_parts(inner, Id::default(), tag)
    }

    /// Consumes the value and hands back its components as
    /// `(ciphertext, id, tag)`.
    pub fn into_raw_parts(self) -> (CompressedSignedRadixCiphertext, Id, Tag) {
        (self.ciphertext, self.id, self.tag)
    }

    /// Reassembles a `CompressedFheInt` from its components.
    ///
    /// No validation is performed: the parts are stored exactly as given.
    pub fn from_raw_parts(ciphertext: CompressedSignedRadixCiphertext, id: Id, tag: Tag) -> Self {
        Self { ciphertext, id, tag }
    }
}
impl<Id> CompressedFheInt<Id>
where
    Id: FheIntId,
{
    /// Expands this compressed value back into an operable [`FheInt`].
    ///
    /// How decompression works depends on the variant:
    /// - A seeded ciphertext is decompressed directly, without consulting the global
    ///   server key.
    /// - A modulus-switched ciphertext is decompressed with the CPU server key from the
    ///   global state.
    ///
    /// The tag is cloned into the result. The re-randomization metadata is reset to its
    /// default.
    ///
    /// # Panics
    ///
    /// Panics for a modulus-switched ciphertext when the active server key is a GPU key
    /// (`gpu` feature) or an HPU key (`hpu` feature).
    /// NOTE(review): `global_state::with_internal_keys` presumably also panics when no
    /// server key has been set. Confirm this in `global_state`.
    pub fn decompress(&self) -> FheInt<Id> {
        let decompressed = match &self.ciphertext {
            CompressedSignedRadixCiphertext::Seeded(seeded) => seeded.decompress(),
            CompressedSignedRadixCiphertext::ModulusSwitched(switched) => {
                global_state::with_internal_keys(|keys| match keys {
                    InternalServerKey::Cpu(cpu_key) => {
                        let integer_key = cpu_key.pbs_key();
                        integer_key.decompress_signed_parallelized(switched)
                    }
                    #[cfg(feature = "gpu")]
                    InternalServerKey::Cuda(_) => {
                        panic!("decompress() on FheInt is not supported on GPU, use a CompressedCiphertextList instead");
                    }
                    #[cfg(feature = "hpu")]
                    InternalServerKey::Hpu(_) => {
                        panic!("decompress() on FheInt is not supported on HPU devices");
                    }
                })
            }
        };
        let tag = self.tag.clone();
        FheInt::new(decompressed, tag, ReRandomizationMetadata::default())
    }
}
impl<Id, T> FheTryEncrypt<T, ClientKey> for CompressedFheInt<Id>
where
    Id: FheIntId,
    T: DecomposableInto<u64> + SignedNumeric,
{
    type Error = crate::Error;

    /// Encrypts `value` under `key` directly into the seeded compressed form.
    ///
    /// The number of radix blocks comes from `Id` and the key's message modulus. The
    /// key's tag is copied onto the result. This implementation never returns `Err`.
    fn try_encrypt(value: T, key: &ClientKey) -> Result<Self, Self::Error> {
        let block_count = Id::num_blocks(key.message_modulus());
        let seeded = key
            .key
            .key
            .encrypt_signed_radix_compressed(value, block_count);
        let payload = CompressedSignedRadixCiphertext::Seeded(seeded);
        Ok(Self::new(payload, key.tag.clone()))
    }
}
/// Parameter set used to check that a [`CompressedFheInt`] is conformant.
///
/// It can be built from anything convertible into `AtomicPatternParameters`, or from a
/// high-level `ServerKey`.
#[derive(Copy, Clone)]
pub struct CompressedFheIntConformanceParams<Id: FheIntId> {
    /// Conformance parameters for the inner compressed signed radix ciphertext.
    pub(crate) params: CompressedSignedRadixCiphertextConformanceParams,
    /// Binds the parameter set to the integer type `Id` without storing an `Id` value.
    pub(crate) id: PhantomData<Id>,
}
impl<Id: FheIntId, P: Into<AtomicPatternParameters>> From<P>
    for CompressedFheIntConformanceParams<Id>
{
    /// Builds the conformance parameters from atomic-pattern parameters.
    ///
    /// The per-integer block count is derived from `Id` and the parameters' message
    /// modulus.
    fn from(params: P) -> Self {
        let ap_params: AtomicPatternParameters = params.into();
        let block_count = Id::num_blocks(ap_params.message_modulus());
        let modswitched = CompressedModulusSwitchedRadixCiphertextConformanceParams {
            shortint_params: ap_params.to_compressed_modswitched_conformance_param(),
            num_blocks_per_integer: block_count,
        };
        Self {
            params: CompressedSignedRadixCiphertextConformanceParams(modswitched),
            id: PhantomData,
        }
    }
}
impl<Id: FheIntId> From<&ServerKey> for CompressedFheIntConformanceParams<Id> {
    /// Derives the conformance parameters from the integer PBS key inside `sk`.
    ///
    /// Both the shortint modulus-switch parameters and the message modulus (used for
    /// the block count) are read from that key.
    fn from(sk: &ServerKey) -> Self {
        let integer_key = sk.key.pbs_key();
        let shortint_params = integer_key.key.compressed_modswitched_conformance_params();
        let num_blocks_per_integer = Id::num_blocks(integer_key.message_modulus());
        let modswitched = CompressedModulusSwitchedRadixCiphertextConformanceParams {
            shortint_params,
            num_blocks_per_integer,
        };
        Self {
            params: CompressedSignedRadixCiphertextConformanceParams(modswitched),
            id: PhantomData,
        }
    }
}
impl<Id: FheIntId> ParameterSetConformant for CompressedFheInt<Id> {
    type ParameterSet = CompressedFheIntConformanceParams<Id>;

    /// Returns `true` when the wrapped ciphertext conforms to `params`.
    ///
    /// Only the ciphertext is checked. The `id` marker and the `tag` carry no
    /// cryptographic parameters.
    fn is_conformant(&self, params: &CompressedFheIntConformanceParams<Id>) -> bool {
        // The destructuring is exhaustive (no `..`), so adding a field to
        // `CompressedFheInt` forces this check to be revisited.
        let Self {
            ciphertext,
            id: _,
            tag: _,
        } = self;
        ciphertext.is_conformant(&params.params)
    }
}
// Type name reported through the `Named` trait.
// NOTE(review): this string is presumably used to identify the type in serialized
// data, so changing it could break compatibility. Confirm against `crate::named`.
impl<Id: FheIntId> Named for CompressedFheInt<Id> {
    const NAME: &'static str = "high_level_api::CompressedFheInt";
}
/// The two supported compressed representations of a signed radix ciphertext.
///
/// NOTE(review): variant order is part of the serde/Versionize encoding, so do not
/// reorder the variants.
#[derive(Clone, serde::Serialize, serde::Deserialize, Versionize)]
#[versionize(CompressedSignedRadixCiphertextVersions)]
pub enum CompressedSignedRadixCiphertext {
    /// Seeded encryption, as produced by client-key encryption
    /// (`FheTryEncrypt for CompressedFheInt`).
    Seeded(IntegerCompressedSignedRadixCiphertext),
    /// Modulus-switched ciphertext, as produced by `FheInt::compress`.
    /// Decompressing it requires a CPU server key.
    ModulusSwitched(CompressedModulusSwitchedSignedRadixCiphertext),
}
/// Conformance parameters for [`CompressedSignedRadixCiphertext`].
///
/// Only the modulus-switched parameters are stored. When a seeded ciphertext is
/// checked, its parameters are derived from these via `Into`.
#[derive(Copy, Clone)]
pub struct CompressedSignedRadixCiphertextConformanceParams(
    pub(crate) CompressedModulusSwitchedRadixCiphertextConformanceParams,
);
impl ParameterSetConformant for CompressedSignedRadixCiphertext {
    type ParameterSet = CompressedSignedRadixCiphertextConformanceParams;

    /// Checks the active variant against the matching parameter set.
    ///
    /// - Seeded ciphertexts are checked against parameters converted from the stored
    ///   modulus-switched parameters. The inner value is `Copy`, because the wrapper
    ///   derives `Copy`, so it can be moved out of the reference.
    /// - Modulus-switched ciphertexts are checked against the stored parameters as-is.
    fn is_conformant(&self, params: &CompressedSignedRadixCiphertextConformanceParams) -> bool {
        match self {
            Self::Seeded(ct) => ct.is_conformant(&params.0.into()),
            Self::ModulusSwitched(ct) => ct.is_conformant(&params.0),
        }
    }
}
impl<Id> FheInt<Id>
where
    Id: FheIntId,
{
    /// Compresses this integer by modulus switching, using the CPU server key from the
    /// global state.
    ///
    /// The result is a `ModulusSwitched` compressed ciphertext that carries a clone of
    /// this value's tag.
    ///
    /// # Panics
    ///
    /// Panics when the active server key is a GPU key (`gpu` feature) or an HPU key
    /// (`hpu` feature).
    /// NOTE(review): `global_state::with_internal_keys` presumably also panics when no
    /// server key has been set. Confirm this in `global_state`.
    pub fn compress(&self) -> CompressedFheInt<Id> {
        global_state::with_internal_keys(|keys| match keys {
            InternalServerKey::Cpu(cpu_key) => {
                let integer_key = cpu_key.pbs_key();
                let switched = integer_key
                    .switch_modulus_and_compress_signed_parallelized(&self.ciphertext.on_cpu());
                let payload = CompressedSignedRadixCiphertext::ModulusSwitched(switched);
                CompressedFheInt::new(payload, self.tag.clone())
            }
            #[cfg(feature = "gpu")]
            InternalServerKey::Cuda(_) => {
                panic!("compress() on FheInt is not supported on GPU, use a CompressedCiphertextList instead");
            }
            #[cfg(feature = "hpu")]
            InternalServerKey::Hpu(_) => {
                panic!("compress() on FheInt is not supported on HPU devices");
            }
        })
    }
}