use serde::{Deserialize, Serialize};
use tfhe_versionable::Versionize;
use crate::core_crypto::backward_compatibility::commons::dispersion::{
StandardDevVersions, VarianceVersions,
};
/// A quantity describing the dispersion (spread) of a distribution.
///
/// Implementors can be converted into each of the representations defined alongside this
/// trait: standard deviation, variance and base-2 logarithm of the standard deviation. Each
/// is available either normalized (relative to a modulus of 1) or scaled by an explicit
/// `modulus`.
///
/// The `modulus` arguments take the modulus value itself as an `f64` (implementations
/// multiply by it and take its `log2`), not its base-2 logarithm.
pub trait DispersionParameter: Copy {
/// Returns the normalized standard deviation.
fn get_standard_dev(&self) -> StandardDev;
/// Returns the normalized variance (square of the normalized standard deviation).
fn get_variance(&self) -> Variance;
/// Returns the base-2 logarithm of the normalized standard deviation.
fn get_log_standard_dev(&self) -> LogStandardDev;
/// Returns the standard deviation multiplied by `modulus`.
fn get_modular_standard_dev(&self, modulus: f64) -> ModularStandardDev;
/// Returns the square of the standard deviation multiplied by `modulus`.
fn get_modular_variance(&self, modulus: f64) -> ModularVariance;
/// Returns `log2(std * modulus)`, i.e. the normalized log2 standard deviation plus `log2(modulus)`.
fn get_modular_log_standard_dev(&self, modulus: f64) -> ModularLogStandardDev;
}
/// Base-2 logarithm of a normalized standard deviation.
///
/// The wrapped value `s` stands for a standard deviation of `2^s` relative to a modulus of 1
/// (see its `DispersionParameter::get_standard_dev` implementation).
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)]
pub struct LogStandardDev(pub f64);
/// Base-2 logarithm of a standard deviation scaled by a modulus, i.e. `log2(std * modulus)`,
/// stored together with the modulus it refers to.
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)]
pub struct ModularLogStandardDev {
/// `log2` of the modulus-scaled standard deviation.
pub value: f64,
/// The modulus value itself (not its log2).
pub modulus: f64,
}
impl LogStandardDev {
    /// Wraps an already-normalized base-2 log standard deviation unchanged.
    pub fn from_log_standard_dev(log_std: f64) -> Self {
        Self(log_std)
    }

    /// Normalizes a log2 standard deviation measured against the modulus `2^log2_modulus`.
    ///
    /// Dividing the standard deviation by `2^log2_modulus` is a subtraction of
    /// `log2_modulus` in the log domain.
    pub fn from_modular_log_standard_dev(log_std: f64, log2_modulus: u32) -> Self {
        let shift = f64::from(log2_modulus);
        Self(log_std - shift)
    }
}
impl DispersionParameter for LogStandardDev {
    /// Undoes the logarithm: the normalized standard deviation is `2^log_std`.
    fn get_standard_dev(&self) -> StandardDev {
        let log_std = self.0;
        StandardDev(2_f64.powf(log_std))
    }

    /// The variance is `std^2 = 2^(2 * log_std)`.
    fn get_variance(&self) -> Variance {
        Variance(2_f64.powf(2. * self.0))
    }

    /// Already in log form; returns a copy of itself.
    fn get_log_standard_dev(&self) -> Self {
        *self
    }

    /// Scales the normalized standard deviation `2^log_std` by `modulus`.
    fn get_modular_standard_dev(&self, modulus: f64) -> ModularStandardDev {
        let value = self.get_standard_dev().0 * modulus;
        ModularStandardDev { value, modulus }
    }

    /// Squares the modulus-scaled standard deviation.
    fn get_modular_variance(&self, modulus: f64) -> ModularVariance {
        let scaled_std = self.get_modular_standard_dev(modulus).value;
        ModularVariance {
            value: scaled_std * scaled_std,
            modulus,
        }
    }

    /// `log2(std * modulus) = log_std + log2(modulus)`.
    fn get_modular_log_standard_dev(&self, modulus: f64) -> ModularLogStandardDev {
        ModularLogStandardDev {
            value: self.0 + modulus.log2(),
            modulus,
        }
    }
}
/// Standard deviation normalized by the modulus, i.e. expressed relative to a modulus of 1.
///
/// Serializable with serde and versioned through `StandardDevVersions`.
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Serialize, Deserialize, Versionize)]
#[versionize(StandardDevVersions)]
pub struct StandardDev(pub f64);
/// Standard deviation scaled by a modulus (`std * modulus`), stored together with that modulus.
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)]
pub struct ModularStandardDev {
/// The modulus-scaled standard deviation.
pub value: f64,
/// The modulus value itself (not its log2).
pub modulus: f64,
}
impl StandardDev {
    /// Wraps an already-normalized standard deviation unchanged.
    pub fn from_standard_dev(std: f64) -> Self {
        Self(std)
    }

    /// Normalizes a standard deviation measured against the modulus `2^log2_modulus`
    /// by dividing it by that modulus.
    pub fn from_modular_standard_dev(std: f64, log2_modulus: u32) -> Self {
        let modulus = f64::powf(2.0, f64::from(log2_modulus));
        Self(std / modulus)
    }
}
impl DispersionParameter for StandardDev {
    /// Already a normalized standard deviation; returns a copy of itself.
    fn get_standard_dev(&self) -> Self {
        *self
    }

    /// The variance is the square of the standard deviation.
    fn get_variance(&self) -> Variance {
        let std = self.0;
        Variance(std.powi(2))
    }

    /// Base-2 logarithm of the normalized standard deviation.
    fn get_log_standard_dev(&self) -> LogStandardDev {
        let std = self.0;
        LogStandardDev(std.log2())
    }

    /// Multiplies the normalized standard deviation by `modulus`.
    fn get_modular_standard_dev(&self, modulus: f64) -> ModularStandardDev {
        ModularStandardDev {
            value: modulus * self.0,
            modulus,
        }
    }

    /// Squares the modulus-scaled standard deviation.
    fn get_modular_variance(&self, modulus: f64) -> ModularVariance {
        let scaled_std = self.get_modular_standard_dev(modulus).value;
        ModularVariance {
            value: scaled_std * scaled_std,
            modulus,
        }
    }

    /// `log2(std * modulus) = log2(modulus) + log2(std)`.
    fn get_modular_log_standard_dev(&self, modulus: f64) -> ModularLogStandardDev {
        let log_std = self.get_log_standard_dev().0;
        ModularLogStandardDev {
            value: modulus.log2() + log_std,
            modulus,
        }
    }
}
/// Variance normalized by the square of the modulus, i.e. expressed relative to a modulus of 1.
///
/// Serializable with serde and versioned through `VarianceVersions`.
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Serialize, Deserialize, Versionize)]
#[versionize(VarianceVersions)]
pub struct Variance(pub f64);
/// Variance of the modulus-scaled distribution (`variance * modulus^2`), stored together with
/// that modulus.
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)]
pub struct ModularVariance {
/// The modulus-scaled variance.
pub value: f64,
/// The modulus value itself (not its log2).
pub modulus: f64,
}
impl Variance {
    /// Wraps an already-normalized variance unchanged.
    pub fn from_variance(var: f64) -> Self {
        Self(var)
    }

    /// Normalizes a variance measured against `modulus` by dividing it by `modulus^2`.
    ///
    /// Unlike the `StandardDev`/`LogStandardDev` constructors, `modulus` here is the
    /// modulus value itself, not its base-2 logarithm.
    pub fn from_modular_variance(var: f64, modulus: f64) -> Self {
        let modulus_squared = modulus * modulus;
        Self(var / modulus_squared)
    }
}
impl DispersionParameter for Variance {
    /// The standard deviation is the square root of the variance.
    fn get_standard_dev(&self) -> StandardDev {
        let var = self.0;
        StandardDev(var.sqrt())
    }

    /// Already a normalized variance; returns a copy of itself.
    fn get_variance(&self) -> Self {
        *self
    }

    /// Base-2 logarithm of `sqrt(variance)`.
    fn get_log_standard_dev(&self) -> LogStandardDev {
        let std = self.get_standard_dev().0;
        LogStandardDev(std.log2())
    }

    /// Multiplies `sqrt(variance)` by `modulus`.
    fn get_modular_standard_dev(&self, modulus: f64) -> ModularStandardDev {
        let value = self.get_standard_dev().0 * modulus;
        ModularStandardDev { value, modulus }
    }

    /// Variance scales with the square of the modulus.
    fn get_modular_variance(&self, modulus: f64) -> ModularVariance {
        let value = self.0 * modulus * modulus;
        ModularVariance { value, modulus }
    }

    /// `log2(sqrt(variance) * modulus) = log2(modulus) + log2(sqrt(variance))`.
    fn get_modular_log_standard_dev(&self, modulus: f64) -> ModularLogStandardDev {
        let log_std = self.get_log_standard_dev().0;
        ModularLogStandardDev {
            value: modulus.log2() + log_std,
            modulus,
        }
    }
}