use std::ops::RangeInclusive;
use serde::{Deserialize, Serialize};
use tfhe_versionable::Versionize;
pub use super::ciphertext_modulus::CiphertextModulus;
use super::traits::CastInto;
use crate::core_crypto::backward_compatibility::commons::parameters::*;
/// Strongly-typed count of plaintexts, stored as a `usize`.
///
/// Wrapping the raw integer prevents it from being mixed up with the other `usize`-based
/// parameters declared in this module.
#[derive(Copy, Clone, Eq, PartialEq, Debug, Serialize, Deserialize, Versionize)]
#[versionize(PlaintextCountVersions)]
pub struct PlaintextCount(pub usize);
/// Strongly-typed count of cleartexts, stored as a `usize`.
#[derive(Copy, Clone, Eq, PartialEq, Debug, Serialize, Deserialize, Versionize)]
#[versionize(CleartextCountVersions)]
pub struct CleartextCount(pub usize);
/// Strongly-typed count of ciphertexts, stored as a `usize`.
#[derive(Copy, Clone, Eq, PartialEq, Debug, Serialize, Deserialize, Versionize)]
#[versionize(CiphertextCountVersions)]
pub struct CiphertextCount(pub usize);
/// Strongly-typed count of LWE ciphertexts, stored as a `usize`.
///
/// Unlike most count types in this module it also derives `PartialOrd` and `Ord`, so two counts
/// can be compared and sorted directly.
#[derive(
Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Debug, Serialize, Deserialize, Versionize,
)]
#[versionize(LweCiphertextCountVersions)]
pub struct LweCiphertextCount(pub usize);
/// Strongly-typed index of an LWE ciphertext, stored as a `usize`.
///
/// This type only exists when the crate is built with the `gpu` feature.
#[cfg(feature = "gpu")]
#[derive(Copy, Clone, Eq, PartialEq, Debug, Serialize, Deserialize, Versionize)]
#[versionize(LweCiphertextIndexVersions)]
pub struct LweCiphertextIndex(pub usize);
/// Strongly-typed count of GLWE ciphertexts, stored as a `usize`.
#[derive(Copy, Clone, Eq, PartialEq, Debug, Serialize, Deserialize, Versionize)]
#[versionize(GlweCiphertextCountVersions)]
pub struct GlweCiphertextCount(pub usize);
/// Strongly-typed count of GSW ciphertexts, stored as a `usize`.
#[derive(Copy, Clone, Eq, PartialEq, Debug, Serialize, Deserialize, Versionize)]
#[versionize(GswCiphertextCountVersions)]
pub struct GswCiphertextCount(pub usize);
/// Strongly-typed count of GGSW ciphertexts, stored as a `usize`.
#[derive(Copy, Clone, Eq, PartialEq, Debug, Serialize, Deserialize, Versionize)]
#[versionize(GgswCiphertextCountVersions)]
pub struct GgswCiphertextCount(pub usize);
/// Size of an LWE ciphertext, stored as a `usize`.
///
/// The companion conversion `to_lwe_dimension` maps a size to the size minus one, so a size is
/// one larger than the corresponding [`LweDimension`].
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Copy, Clone, Serialize, Deserialize, Versionize,
)]
#[versionize(LweSizeVersions)]
pub struct LweSize(pub usize);
impl LweSize {
    /// Converts this size into the matching [`LweDimension`], i.e. the size minus one.
    ///
    /// The subtraction saturates at zero, so `LweSize(0)` maps to `LweDimension(0)` instead of
    /// underflowing.
    pub fn to_lwe_dimension(&self) -> LweDimension {
        let size = self.0;
        LweDimension(size.saturating_sub(1))
    }
}
/// Dimension of an LWE instance, stored as a `usize`.
///
/// The companion conversion `to_lwe_size` maps a dimension to the dimension plus one. This type
/// also derives `Hash`, so it can be used as a key in hashed collections.
#[derive(
Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize, Versionize,
)]
#[versionize(LweDimensionVersions)]
pub struct LweDimension(pub usize);
impl LweDimension {
    /// Converts this dimension into the matching [`LweSize`], i.e. the dimension plus one.
    ///
    /// The addition saturates, so `LweDimension(usize::MAX)` maps to `LweSize(usize::MAX)`
    /// instead of overflowing.
    pub fn to_lwe_size(&self) -> LweSize {
        let dimension = self.0;
        LweSize(dimension.saturating_add(1))
    }
}
/// Strongly-typed count of zero encryptions associated with an LWE public key, stored as a
/// `usize`.
///
/// NOTE(review): the description follows the type name; how the count is consumed is not
/// visible in this file.
#[derive(Copy, Clone, Eq, PartialEq, Debug, Serialize, Deserialize, Versionize)]
#[versionize(LwePublicKeyZeroEncryptionCountVersions)]
pub struct LwePublicKeyZeroEncryptionCount(pub usize);
/// Strongly-typed count of LWE masks, stored as a `usize`.
#[derive(Copy, Clone, Eq, PartialEq, Debug, Serialize, Deserialize, Versionize)]
#[versionize(LweMaskCountVersions)]
pub struct LweMaskCount(pub usize);
/// Strongly-typed count of LWE bodies, stored as a `usize`.
#[derive(Copy, Clone, Eq, PartialEq, Debug, Serialize, Deserialize, Versionize)]
#[versionize(LweBodyCountVersions)]
pub struct LweBodyCount(pub usize);
/// Size of a GLWE ciphertext, stored as a `usize`.
///
/// The companion conversion `to_glwe_dimension` maps a size to the size minus one, so a size is
/// one larger than the corresponding [`GlweDimension`].
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Copy, Clone, Serialize, Deserialize, Versionize,
)]
#[versionize(GlweSizeVersions)]
pub struct GlweSize(pub usize);
impl GlweSize {
    /// Converts this size into the matching [`GlweDimension`], i.e. the size minus one.
    ///
    /// The subtraction saturates at zero, mirroring [`LweSize::to_lwe_dimension`]: a degenerate
    /// `GlweSize(0)` yields `GlweDimension(0)` rather than panicking in builds with overflow
    /// checks or wrapping around to `usize::MAX` in builds without them.
    pub fn to_glwe_dimension(&self) -> GlweDimension {
        GlweDimension(self.0.saturating_sub(1))
    }
}
/// Dimension of a GLWE instance, stored as a `usize`.
///
/// The companion conversions map it to a [`GlweSize`] (dimension plus one) and, given a
/// [`PolynomialSize`], to an [`LweDimension`] (dimension times polynomial size).
#[derive(
Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Serialize, Deserialize, Versionize,
)]
#[versionize(GlweDimensionVersions)]
pub struct GlweDimension(pub usize);
impl GlweDimension {
    /// Converts this dimension into the matching [`GlweSize`], i.e. the dimension plus one.
    ///
    /// The addition saturates, mirroring [`LweDimension::to_lwe_size`]: a degenerate
    /// `GlweDimension(usize::MAX)` yields `GlweSize(usize::MAX)` rather than panicking in builds
    /// with overflow checks or wrapping around to zero in builds without them.
    pub fn to_glwe_size(&self) -> GlweSize {
        GlweSize(self.0.saturating_add(1))
    }

    /// Returns the [`LweDimension`] obtained by multiplying this dimension by `poly_size`.
    ///
    /// Usable in `const` contexts. The multiplication is a plain `*`, so it follows the usual
    /// integer-overflow rules of the build profile.
    pub const fn to_equivalent_lwe_dimension(self, poly_size: PolynomialSize) -> LweDimension {
        LweDimension(self.0 * poly_size.0)
    }
}
/// Size of a polynomial, stored as a `usize` (presumably its number of coefficients — confirm
/// against the polynomial containers).
///
/// The companion conversions provide its base-2 logarithm, the halved
/// [`FourierPolynomialSize`], and a [`CiphertextModulusLog`] equal to `log2 + 1`.
#[derive(
Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Versionize, Hash,
)]
#[versionize(PolynomialSizeVersions)]
pub struct PolynomialSize(pub usize);
impl PolynomialSize {
    /// Returns the base-2 logarithm of the size, rounded down, as a [`PolynomialSizeLog`].
    ///
    /// # Panics
    ///
    /// Panics if the size is zero, since `ilog2` is undefined for zero.
    pub const fn log2(&self) -> PolynomialSizeLog {
        let floor_log: u32 = self.0.ilog2();
        PolynomialSizeLog(floor_log as usize)
    }

    /// Returns the [`FourierPolynomialSize`] for this size, which is exactly half of it.
    ///
    /// # Panics
    ///
    /// Panics if the size is not a multiple of two.
    pub fn to_fourier_polynomial_size(&self) -> FourierPolynomialSize {
        let size = self.0;
        assert_eq!(
            size % 2,
            0,
            "Cannot convert a PolynomialSize that is not a multiple of 2 to FourierPolynomialSize"
        );
        FourierPolynomialSize(size / 2)
    }

    /// Returns a [`CiphertextModulusLog`] equal to `log2(size) + 1`.
    ///
    /// # Panics
    ///
    /// Panics if the size is zero (see [`PolynomialSize::log2`]).
    pub const fn to_blind_rotation_input_modulus_log(&self) -> CiphertextModulusLog {
        let size_log = self.log2().0;
        CiphertextModulusLog(size_log + 1)
    }
}
/// Size of a polynomial in the Fourier domain, stored as a `usize`.
///
/// The conversions in this module relate it to [`PolynomialSize`] by a factor of two: the
/// Fourier size is half of the standard size.
#[derive(
Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Versionize,
)]
#[versionize(FourierPolynomialSizeVersions)]
pub struct FourierPolynomialSize(pub usize);
impl FourierPolynomialSize {
    /// Returns the standard-domain [`PolynomialSize`], which is twice this Fourier size.
    pub fn to_standard_polynomial_size(&self) -> PolynomialSize {
        let fourier_size = self.0;
        PolynomialSize(fourier_size * 2)
    }
}
/// Base-2 logarithm of a polynomial size, stored as a `usize`.
///
/// It is produced by `PolynomialSize::log2` and converted back with `to_polynomial_size`, which
/// computes `1 << log`.
#[derive(
Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Versionize,
)]
#[versionize(PolynomialSizeLogVersions)]
pub struct PolynomialSizeLog(pub usize);
impl PolynomialSizeLog {
    /// Returns the [`PolynomialSize`] whose base-2 logarithm is this value, i.e. `1 << log`.
    pub fn to_polynomial_size(&self) -> PolynomialSize {
        let size_log = self.0;
        PolynomialSize(1 << size_log)
    }
}
/// Strongly-typed count of polynomials, stored as a `usize`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Versionize)]
#[versionize(PolynomialCountVersions)]
pub struct PolynomialCount(pub usize);
/// Degree of a monomial, stored as a `usize`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Versionize)]
#[versionize(MonomialDegreeVersions)]
pub struct MonomialDegree(pub usize);
/// Logarithm of a decomposition base, stored as a `usize`.
///
/// NOTE(review): presumably the base-2 logarithm; the base is not stated in this file.
#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize, Versionize)]
#[versionize(DecompositionBaseLogVersions)]
pub struct DecompositionBaseLog(pub usize);
/// Strongly-typed count of decomposition levels, stored as a `usize`.
#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize, Versionize)]
#[versionize(DecompositionLevelCountVersions)]
pub struct DecompositionLevelCount(pub usize);
/// Logarithm of a LUT count, stored as a `usize`.
///
/// NOTE(review): "LUT" presumably stands for look-up table and the logarithm is presumably
/// base 2; neither is established in this file.
#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize, Versionize)]
#[versionize(LutCountLogVersions)]
pub struct LutCountLog(pub usize);
/// Offset used in relation to a modulus switch, stored as a `usize`.
///
/// NOTE(review): the exact meaning of the offset is defined by its callers, not in this file.
#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize, Versionize)]
#[versionize(ModulusSwitchOffsetVersions)]
pub struct ModulusSwitchOffset(pub usize);
/// Logarithm of a `delta` value, stored as a `usize`.
///
/// NOTE(review): presumably the base-2 logarithm of an encoding scaling factor; confirm against
/// the code that consumes it.
#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize, Versionize)]
#[versionize(DeltaLogVersions)]
pub struct DeltaLog(pub usize);
/// Strongly-typed count of extracted bits, stored as a `usize`.
#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize, Versionize)]
#[versionize(ExtractedBitsCountVersions)]
pub struct ExtractedBitsCount(pub usize);
/// Strongly-typed count of functional packing keyswitch keys, stored as a `usize`.
#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize, Versionize)]
#[versionize(FunctionalPackingKeyswitchKeyCountVersions)]
pub struct FunctionalPackingKeyswitchKeyCount(pub usize);
/// Logarithm of a ciphertext modulus, stored as a `usize`.
///
/// `PolynomialSize::to_blind_rotation_input_modulus_log` builds one as `log2(size) + 1`, which
/// suggests the logarithm is base 2.
#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize, Versionize)]
#[versionize(CiphertextModulusLogVersions)]
pub struct CiphertextModulusLog(pub usize);
/// Logarithm of a message modulus, stored as a `usize`.
///
/// NOTE(review): presumably the base-2 logarithm; the base is not stated in this file.
#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize, Versionize)]
#[versionize(MessageModulusLogVersions)]
pub struct MessageModulusLog(pub usize);
/// Strongly-typed count of threads, stored as a `usize`.
#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize, Versionize)]
#[versionize(ThreadCountVersions)]
pub struct ThreadCount(pub usize);
/// Grouping factor of an LWE bootstrap key ("BSK" presumably stands for bootstrap key), stored
/// as a `usize`.
///
/// `ggsw_per_multi_bit_element` turns a grouping factor `g` into `1 << g`. This type also
/// derives `Hash`, so it can be used as a key in hashed collections.
#[derive(Debug, PartialEq, Eq, Copy, Clone, Hash, Serialize, Deserialize, Versionize)]
#[versionize(LweBskGroupingFactorVersions)]
pub struct LweBskGroupingFactor(pub usize);
impl LweBskGroupingFactor {
    /// Returns the [`GgswPerLweMultiBitBskElement`] for this grouping factor, computed as
    /// `1 << grouping_factor`.
    pub fn ggsw_per_multi_bit_element(&self) -> GgswPerLweMultiBitBskElement {
        let grouping_factor = self.0;
        GgswPerLweMultiBitBskElement(1 << grouping_factor)
    }
}
/// Number of GGSW ciphertexts per element of a multi-bit LWE bootstrap key (going by the type
/// name), stored as a `usize`.
///
/// `LweBskGroupingFactor::ggsw_per_multi_bit_element` builds it as `1 << grouping_factor`.
#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize, Versionize)]
#[versionize(GgswPerLweMultiBitBskElementVersions)]
pub struct GgswPerLweMultiBitBskElement(pub usize);
/// Choice between the "big" and the "small" key for encryption.
///
/// The conversions in this module pair each variant with a [`PBSOrder`] and with a `usize`
/// index (`Big` is 0, `Small` is 1).
#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize, Versionize)]
#[versionize(EncryptionKeyChoiceVersions)]
pub enum EncryptionKeyChoice {
/// Encrypt under the big key; converts to `PBSOrder::KeyswitchBootstrap`.
Big,
/// Encrypt under the small key; converts to `PBSOrder::BootstrapKeyswitch`.
Small,
}
impl EncryptionKeyChoice {
pub const fn into_pbs_order(self) -> PBSOrder {
match self {
Self::Big => PBSOrder::KeyswitchBootstrap,
Self::Small => PBSOrder::BootstrapKeyswitch,
}
}
}
impl From<EncryptionKeyChoice> for PBSOrder {
fn from(value: EncryptionKeyChoice) -> Self {
value.into_pbs_order()
}
}
impl From<PBSOrder> for EncryptionKeyChoice {
fn from(value: PBSOrder) -> Self {
match value {
PBSOrder::KeyswitchBootstrap => Self::Big,
PBSOrder::BootstrapKeyswitch => Self::Small,
}
}
}
impl From<EncryptionKeyChoice> for usize {
    /// Converts a key choice into a numeric index: `Big` is 0 and `Small` is 1.
    fn from(value: EncryptionKeyChoice) -> Self {
        match value {
            EncryptionKeyChoice::Small => 1,
            EncryptionKeyChoice::Big => 0,
        }
    }
}
/// Order in which the keyswitch and the bootstrap are applied in a PBS ("PBS" presumably stands
/// for programmable bootstrapping).
///
/// The discriminants are fixed explicitly to 0 and 1.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Versionize)]
#[versionize(PBSOrderVersions)]
pub enum PBSOrder {
/// Going by the variant name, the keyswitch comes first and the bootstrap second. Paired with
/// `EncryptionKeyChoice::Big`.
KeyswitchBootstrap = 0,
/// Going by the variant name, the bootstrap comes first and the keyswitch second. Paired with
/// `EncryptionKeyChoice::Small`.
BootstrapKeyswitch = 1,
}
pub use crate::core_crypto::commons::math::random::DynamicDistribution;
/// Strongly-typed count of mask samples used by an encryption, stored as a `usize`.
///
/// It can be scaled by a plain `usize` through the `Mul` impls below and converted to a byte
/// count with `to_mask_byte_count`. Unlike most types in this module it is neither
/// serializable nor versioned.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub struct EncryptionMaskSampleCount(pub usize);
impl EncryptionMaskSampleCount {
pub(crate) fn to_mask_byte_count(
self,
mask_byte_per_scalar: EncryptionMaskByteCount,
) -> EncryptionMaskByteCount {
EncryptionMaskByteCount(self.0 * mask_byte_per_scalar.0)
}
}
impl std::ops::Mul<usize> for EncryptionMaskSampleCount {
type Output = Self;
fn mul(self, rhs: usize) -> Self::Output {
Self(self.0 * rhs)
}
}
impl std::ops::Mul<EncryptionMaskSampleCount> for usize {
type Output = EncryptionMaskSampleCount;
fn mul(self, rhs: EncryptionMaskSampleCount) -> Self::Output {
EncryptionMaskSampleCount(self * rhs.0)
}
}
/// Strongly-typed count of mask bytes used by an encryption, stored as a `usize`.
///
/// `EncryptionMaskSampleCount::to_mask_byte_count` both takes one (bytes per sample) and returns
/// one (total bytes).
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub struct EncryptionMaskByteCount(pub usize);
/// Strongly-typed count of noise samples used by an encryption, stored as a `usize`.
///
/// It can be scaled by a plain `usize` through the `Mul` impls below and converted to a byte
/// count with `to_noise_byte_count`. Unlike most types in this module it is neither
/// serializable nor versioned.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub struct EncryptionNoiseSampleCount(pub usize);
impl EncryptionNoiseSampleCount {
pub(crate) fn to_noise_byte_count(
self,
noise_byte_per_scalar: EncryptionNoiseByteCount,
) -> EncryptionNoiseByteCount {
EncryptionNoiseByteCount(self.0 * noise_byte_per_scalar.0)
}
}
impl std::ops::Mul<usize> for EncryptionNoiseSampleCount {
type Output = Self;
fn mul(self, rhs: usize) -> Self::Output {
Self(self.0 * rhs)
}
}
impl std::ops::Mul<EncryptionNoiseSampleCount> for usize {
type Output = EncryptionNoiseSampleCount;
fn mul(self, rhs: EncryptionNoiseSampleCount) -> Self::Output {
EncryptionNoiseSampleCount(self * rhs.0)
}
}
/// Strongly-typed count of noise bytes used by an encryption, stored as a `usize`.
///
/// `EncryptionNoiseSampleCount::to_noise_byte_count` both takes one (bytes per sample) and
/// returns one (total bytes).
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub struct EncryptionNoiseByteCount(pub usize);
/// Strongly-typed `f64` factor named "r sigma".
///
/// Because it wraps a float it derives `PartialEq` but not `Eq`.
///
/// NOTE(review): the statistical meaning of the factor is not established in this file.
#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize, Versionize)]
#[versionize(RSigmaFactorVersions)]
pub struct RSigmaFactor(pub f64);
/// Strongly-typed `f64` bound used for a noise estimation measure (going by the type name).
///
/// Because it wraps a float it derives `PartialEq` but not `Eq`.
#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize, Versionize)]
#[versionize(NoiseEstimationMeasureBoundVersions)]
pub struct NoiseEstimationMeasureBound(pub f64);
/// Strongly-typed size of a chunk, stored as a `usize`.
///
/// NOTE(review): what is being chunked is decided by the callers and is not visible here.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Versionize)]
#[versionize(ChunkSizeVersions)]
pub struct ChunkSize(pub usize);
/// Bound on a normalized Hamming weight, stored as an `f64`.
///
/// The inner field is private: values are built with `NormalizedHammingWeightBound::new`, which
/// only accepts bounds in `(0.5, 1.0]`, and read back with `get`.
///
/// NOTE(review): the derived `Deserialize` impl presumably does not go through `new`; confirm
/// whether deserialized values are validated elsewhere.
#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize, Versionize)]
#[versionize(NormalizedHammingWeightBoundVersions)]
pub struct NormalizedHammingWeightBound(f64);
impl NormalizedHammingWeightBound {
pub fn new(pmax: f64) -> Option<Self> {
if 0.5 < pmax && pmax <= 1.0 {
Some(Self(pmax))
} else {
None
}
}
pub fn get(self) -> f64 {
self.0
}
pub fn range(self, num_bits: usize) -> RangeInclusive<u128> {
RangeInclusive::new(
((1.0 - self.0) * num_bits as f64) as u128,
(self.0 * num_bits as f64) as u128,
)
}
#[allow(clippy::result_unit_err)]
pub fn check_binary_slice<T>(self, binary_slice: &[T]) -> Result<(), ()>
where
T: Copy + CastInto<u128>,
{
let hamming_weight = binary_slice
.iter()
.copied()
.map(|bit| -> u128 { bit.cast_into() })
.sum::<u128>();
let bounds = self.range(binary_slice.len());
if bounds.contains(&hamming_weight) {
Ok(())
} else {
Err(())
}
}
}
#[cfg(test)]
mod test {
    use super::*;

    /// A zero `LweSize` must saturate to dimension 0 instead of underflowing.
    #[test]
    fn test_bad_lwe_size() {
        let LweDimension(dimension) = LweSize(0).to_lwe_dimension();
        assert_eq!(dimension, 0);
    }

    /// A maximal `LweDimension` must saturate to size `usize::MAX` instead of overflowing.
    #[test]
    fn test_bad_lwe_dimension() {
        let LweSize(size) = LweDimension(usize::MAX).to_lwe_size();
        assert_eq!(size, usize::MAX);
    }
}