use super::base::{FheUint, FheUintId};
use crate::backward_compatibility::integers::{
InnerSquashedNoiseRadixCiphertextVersionedOwned, SquashedNoiseFheUintVersions,
};
use crate::core_crypto::commons::numeric::UnsignedNumeric;
use crate::high_level_api::details::MaybeCloned;
use crate::high_level_api::errors::UninitializedNoiseSquashing;
use crate::high_level_api::global_state::{self, with_internal_keys};
#[cfg(feature = "gpu")]
use crate::high_level_api::global_state::{
with_cuda_internal_keys, with_thread_local_cuda_streams_for_gpu_indexes,
};
use crate::high_level_api::keys::InternalServerKey;
use crate::high_level_api::traits::{FheDecrypt, SquashNoise, Tagged};
use crate::high_level_api::SquashedNoiseCiphertextState;
use crate::integer::block_decomposition::RecomposableFrom;
#[cfg(feature = "gpu")]
use crate::integer::gpu::ciphertext::squashed_noise::CudaSquashedNoiseRadixCiphertext;
use crate::named::Named;
use crate::{ClientKey, Device, Tag};
use serde::{Deserializer, Serializer};
use tfhe_versionable::{Unversionize, UnversionizeError, Versionize, VersionizeOwned};
pub(in crate::high_level_api) enum InnerSquashedNoiseRadixCiphertext {
Cpu(crate::integer::ciphertext::SquashedNoiseRadixCiphertext),
#[cfg(feature = "gpu")]
Cuda(CudaSquashedNoiseRadixCiphertext),
}
impl Clone for InnerSquashedNoiseRadixCiphertext {
fn clone(&self) -> Self {
match self {
Self::Cpu(inner) => Self::Cpu(inner.clone()),
#[cfg(feature = "gpu")]
Self::Cuda(inner) => {
with_thread_local_cuda_streams_for_gpu_indexes(inner.gpu_indexes(), |streams| {
Self::Cuda(inner.duplicate(streams))
})
}
}
}
}
impl serde::Serialize for InnerSquashedNoiseRadixCiphertext {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
self.on_cpu().serialize(serializer)
}
}
impl<'de> serde::Deserialize<'de> for InnerSquashedNoiseRadixCiphertext {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let mut deserialized = Self::Cpu(
crate::integer::ciphertext::SquashedNoiseRadixCiphertext::deserialize(deserializer)?,
);
deserialized.move_to_device_of_server_key_if_set();
Ok(deserialized)
}
}
#[derive(serde::Serialize, serde::Deserialize)]
#[cfg_attr(dylint_lib = "tfhe_lints", allow(serialize_without_versionize))]
pub(crate) struct InnerSquashedNoiseRadixCiphertextVersionOwned(
<crate::integer::ciphertext::SquashedNoiseRadixCiphertext as VersionizeOwned>::VersionedOwned,
);
impl Versionize for InnerSquashedNoiseRadixCiphertext {
type Versioned<'vers> = InnerSquashedNoiseRadixCiphertextVersionedOwned;
fn versionize(&self) -> Self::Versioned<'_> {
let data = self.on_cpu();
let versioned = data.into_owned().versionize_owned();
InnerSquashedNoiseRadixCiphertextVersionedOwned::V0(
InnerSquashedNoiseRadixCiphertextVersionOwned(versioned),
)
}
}
impl VersionizeOwned for InnerSquashedNoiseRadixCiphertext {
type VersionedOwned = InnerSquashedNoiseRadixCiphertextVersionedOwned;
fn versionize_owned(self) -> Self::VersionedOwned {
let cpu_data = self.on_cpu();
InnerSquashedNoiseRadixCiphertextVersionedOwned::V0(
InnerSquashedNoiseRadixCiphertextVersionOwned(cpu_data.into_owned().versionize_owned()),
)
}
}
impl Unversionize for InnerSquashedNoiseRadixCiphertext {
fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, UnversionizeError> {
match versioned {
InnerSquashedNoiseRadixCiphertextVersionedOwned::V0(v0) => {
let mut unversioned = Self::Cpu(
crate::integer::ciphertext::SquashedNoiseRadixCiphertext::unversionize(v0.0)?,
);
unversioned.move_to_device_of_server_key_if_set();
Ok(unversioned)
}
}
}
}
impl InnerSquashedNoiseRadixCiphertext {
pub(crate) fn on_cpu(
&self,
) -> MaybeCloned<'_, crate::integer::ciphertext::SquashedNoiseRadixCiphertext> {
match self {
Self::Cpu(ct) => MaybeCloned::Borrowed(ct),
#[cfg(feature = "gpu")]
Self::Cuda(ct) => {
with_thread_local_cuda_streams_for_gpu_indexes(ct.gpu_indexes(), |streams| {
MaybeCloned::Cloned(ct.to_squashed_noise_radix_ciphertext(streams))
})
}
}
}
fn current_device(&self) -> crate::Device {
match self {
Self::Cpu(_) => crate::Device::Cpu,
#[cfg(feature = "gpu")]
Self::Cuda(_) => crate::Device::CudaGpu,
}
}
#[allow(clippy::needless_pass_by_ref_mut)]
fn move_to_device(&mut self, target_device: Device) {
let current_device = self.current_device();
if current_device == target_device {
#[cfg(feature = "gpu")]
if let Self::Cuda(cuda_ct) = self {
with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
if cuda_ct.gpu_indexes() != streams.gpu_indexes() {
*cuda_ct = cuda_ct.duplicate(streams);
}
})
}
return;
}
let cpu_ct = self.on_cpu();
match target_device {
Device::Cpu => {
let _ = cpu_ct;
}
#[cfg(feature = "gpu")]
Device::CudaGpu => {
let new_inner = with_cuda_internal_keys(|keys| {
let streams = &keys.streams;
CudaSquashedNoiseRadixCiphertext::from_squashed_noise_radix_ciphertext(
&cpu_ct, streams,
)
});
*self = Self::Cuda(new_inner);
}
#[cfg(feature = "hpu")]
Device::Hpu => {
panic!("HPU does not support compression");
}
}
}
#[inline]
pub(crate) fn move_to_device_of_server_key_if_set(&mut self) {
if let Some(device) = global_state::device_of_internal_keys() {
self.move_to_device(device);
}
}
}
#[derive(Clone, serde::Deserialize, serde::Serialize, Versionize)]
#[versionize(SquashedNoiseFheUintVersions)]
pub struct SquashedNoiseFheUint {
pub(in crate::high_level_api) inner: InnerSquashedNoiseRadixCiphertext,
pub(in crate::high_level_api) state: SquashedNoiseCiphertextState,
tag: Tag,
}
impl Named for SquashedNoiseFheUint {
const NAME: &'static str = "high_level_api::SquashedNoiseFheUint";
}
impl SquashedNoiseFheUint {
pub(in crate::high_level_api) fn new(
inner: InnerSquashedNoiseRadixCiphertext,
state: SquashedNoiseCiphertextState,
tag: Tag,
) -> Self {
Self { inner, state, tag }
}
pub fn underlying_squashed_noise_ciphertext(
&self,
) -> MaybeCloned<'_, crate::integer::ciphertext::SquashedNoiseRadixCiphertext> {
self.inner.on_cpu()
}
pub fn num_bits(&self) -> usize {
match &self.inner {
InnerSquashedNoiseRadixCiphertext::Cpu(cpu_ct) => {
cpu_ct.original_block_count
* cpu_ct.packed_blocks[0].message_modulus().0.ilog2() as usize
}
#[cfg(feature = "gpu")]
InnerSquashedNoiseRadixCiphertext::Cuda(gpu_ct) => {
gpu_ct.original_block_count
* gpu_ct
.info
.blocks
.first()
.unwrap()
.message_modulus
.0
.ilog2() as usize
}
}
}
}
impl<Clear> FheDecrypt<Clear> for SquashedNoiseFheUint
where
Clear: RecomposableFrom<u128> + UnsignedNumeric,
{
fn decrypt(&self, key: &ClientKey) -> Clear {
let noise_squashing_private_key = key.private_noise_squashing_decryption_key(self.state);
noise_squashing_private_key
.decrypt_radix(&self.inner.on_cpu())
.unwrap()
}
}
impl Tagged for SquashedNoiseFheUint {
fn tag(&self) -> &Tag {
&self.tag
}
fn tag_mut(&mut self) -> &mut Tag {
&mut self.tag
}
}
impl<Id: FheUintId> SquashNoise for FheUint<Id> {
type Output = SquashedNoiseFheUint;
fn squash_noise(&self) -> crate::Result<Self::Output> {
with_internal_keys(|keys| match keys {
InternalServerKey::Cpu(server_key) => {
let noise_squashing_key = server_key
.key
.noise_squashing_key
.as_ref()
.ok_or(UninitializedNoiseSquashing)?;
Ok(SquashedNoiseFheUint {
inner: InnerSquashedNoiseRadixCiphertext::Cpu(
noise_squashing_key.squash_radix_ciphertext_noise(
server_key.key.pbs_key(),
&self.ciphertext.on_cpu(),
)?,
),
state: SquashedNoiseCiphertextState::Normal,
tag: server_key.tag.clone(),
})
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => {
let streams = &cuda_key.streams;
let noise_squashing_key = cuda_key
.key
.noise_squashing_key
.as_ref()
.ok_or(UninitializedNoiseSquashing)?;
let cuda_squashed_ct = noise_squashing_key.squash_radix_ciphertext_noise(
cuda_key.pbs_key(),
&self.ciphertext.on_gpu(streams).ciphertext,
streams,
)?;
let squashed_ct = cuda_squashed_ct.to_squashed_noise_radix_ciphertext(streams);
Ok(SquashedNoiseFheUint {
inner: InnerSquashedNoiseRadixCiphertext::Cpu(squashed_ct),
state: SquashedNoiseCiphertextState::Normal,
tag: cuda_key.tag.clone(),
})
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
Err(crate::error!("Hpu devices do not support noise squashing"))
}
})
}
}