concrete_commons/
dispersion.rs

//! Noise distribution
//!
//! When dealing with noise, we tend to use different representations for the same value. In
//! general, the noise is specified by the standard deviation of a gaussian distribution, which
//! is of the form $\sigma = 2^p$, with $p$ a negative integer. Depending on the use case though,
//! we rely on different representations for this quantity:
//!
//! + $\sigma$ can be encoded in the [`StandardDev`] type.
//! + $p$ can be encoded in the [`LogStandardDev`] type.
//! + $\sigma^2$ can be encoded in the [`Variance`] type.
//!
//! In any of those cases, the corresponding type implements the [`DispersionParameter`] trait,
//! which makes it possible to use any of those representations generically when noise must be
//! defined.
15
16#[cfg(feature = "serde_serialize")]
17use serde::{Deserialize, Serialize};
18
19use crate::numeric::UnsignedInteger;
20
21/// A trait for types representing distribution parameters, for a given unsigned integer type.
22//  Warning:
23//  DispersionParameter type should ONLY wrap a single native type.
24//  As long as Variance wraps a native type (f64) it is ok to derive it from Copy instead of
25//  Clone because f64 is itself Copy and stored in register.
26pub trait DispersionParameter: Copy {
27    /// Returns the standard deviation of the distribution, i.e. $\sigma = 2^p$.
28    fn get_standard_dev(&self) -> f64;
29    /// Returns the variance of the distribution, i.e. $\sigma^2 = 2^{2p}$.
30    fn get_variance(&self) -> f64;
31    /// Returns base 2 logarithm of the standard deviation of the distribution, i.e.
32    /// $\log_2(\sigma)=p$
33    fn get_log_standard_dev(&self) -> f64;
34    /// For a `Uint` type representing $\mathbb{Z}/2^q\mathbb{Z}$, we return $2^{q-p}$.
35    fn get_modular_standard_dev<Uint>(&self) -> f64
36    where
37        Uint: UnsignedInteger;
38    /// For a `Uint` type representing $\mathbb{Z}/2^q\mathbb{Z}$, we return $2^{2(q-p)}$.
39    fn get_modular_variance<Uint>(&self) -> f64
40    where
41        Uint: UnsignedInteger;
42    /// For a `Uint` type representing $\mathbb{Z}/2^q\mathbb{Z}$, we return $q-p$.
43    fn get_modular_log_standard_dev<Uint>(&self) -> f64
44    where
45        Uint: UnsignedInteger;
46}
47
/// A distribution parameter that uses the base-2 logarithm of the standard deviation as
/// representation.
///
/// The wrapped `f64` is $p = \log_2(\sigma)$.
///
/// # Example
///
/// ```
/// use concrete_commons::dispersion::{DispersionParameter, LogStandardDev};
/// let params = LogStandardDev::from_log_standard_dev(-25.);
/// assert_eq!(params.get_standard_dev(), 2_f64.powf(-25.));
/// assert_eq!(params.get_log_standard_dev(), -25.);
/// assert_eq!(params.get_variance(), 2_f64.powf(-25.).powi(2));
/// assert_eq!(
///     params.get_modular_standard_dev::<u32>(),
///     2_f64.powf(32. - 25.)
/// );
/// assert_eq!(params.get_modular_log_standard_dev::<u32>(), 32. - 25.);
/// assert_eq!(
///     params.get_modular_variance::<u32>(),
///     2_f64.powf(32. - 25.).powi(2)
/// );
///
/// let modular_params = LogStandardDev::from_modular_log_standard_dev::<u32>(22.);
/// assert_eq!(modular_params.get_standard_dev(), 2_f64.powf(-10.));
/// ```
#[cfg_attr(feature = "serde_serialize", derive(Serialize, Deserialize))]
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)]
pub struct LogStandardDev(pub f64);
74
75impl LogStandardDev {
76    pub fn from_log_standard_dev(log_std: f64) -> LogStandardDev {
77        LogStandardDev(log_std)
78    }
79
80    pub fn from_modular_log_standard_dev<Uint>(log_std: f64) -> LogStandardDev
81    where
82        Uint: UnsignedInteger,
83    {
84        LogStandardDev(log_std - Uint::BITS as f64)
85    }
86}
87
88impl DispersionParameter for LogStandardDev {
89    fn get_standard_dev(&self) -> f64 {
90        f64::powf(2., self.0)
91    }
92    fn get_variance(&self) -> f64 {
93        f64::powf(2., self.0 * 2.)
94    }
95    fn get_log_standard_dev(&self) -> f64 {
96        self.0
97    }
98    fn get_modular_standard_dev<Uint>(&self) -> f64
99    where
100        Uint: UnsignedInteger,
101    {
102        f64::powf(2., Uint::BITS as f64 + self.0)
103    }
104    fn get_modular_variance<Uint>(&self) -> f64
105    where
106        Uint: UnsignedInteger,
107    {
108        f64::powf(2., (Uint::BITS as f64 + self.0) * 2.)
109    }
110    fn get_modular_log_standard_dev<Uint>(&self) -> f64
111    where
112        Uint: UnsignedInteger,
113    {
114        Uint::BITS as f64 + self.0
115    }
116}
117
/// A distribution parameter that uses the standard deviation as representation.
///
/// The wrapped `f64` is $\sigma$ itself.
///
/// # Example
///
/// ```
/// use concrete_commons::dispersion::{DispersionParameter, StandardDev};
/// let params = StandardDev::from_standard_dev(2_f64.powf(-25.));
/// assert_eq!(params.get_standard_dev(), 2_f64.powf(-25.));
/// assert_eq!(params.get_log_standard_dev(), -25.);
/// assert_eq!(params.get_variance(), 2_f64.powf(-25.).powi(2));
/// assert_eq!(
///     params.get_modular_standard_dev::<u32>(),
///     2_f64.powf(32. - 25.)
/// );
/// assert_eq!(params.get_modular_log_standard_dev::<u32>(), 32. - 25.);
/// assert_eq!(
///     params.get_modular_variance::<u32>(),
///     2_f64.powf(32. - 25.).powi(2)
/// );
/// ```
#[cfg_attr(feature = "serde_serialize", derive(Serialize, Deserialize))]
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)]
pub struct StandardDev(pub f64);
141
142impl StandardDev {
143    pub fn from_standard_dev(std: f64) -> StandardDev {
144        StandardDev(std)
145    }
146
147    pub fn from_modular_standard_dev<Uint>(std: f64) -> StandardDev
148    where
149        Uint: UnsignedInteger,
150    {
151        StandardDev(std / 2_f64.powf(Uint::BITS as f64))
152    }
153}
154
155impl DispersionParameter for StandardDev {
156    fn get_standard_dev(&self) -> f64 {
157        self.0
158    }
159    fn get_variance(&self) -> f64 {
160        self.0.powi(2)
161    }
162    fn get_log_standard_dev(&self) -> f64 {
163        self.0.log2()
164    }
165    fn get_modular_standard_dev<Uint>(&self) -> f64
166    where
167        Uint: UnsignedInteger,
168    {
169        2_f64.powf(Uint::BITS as f64 + self.0.log2())
170    }
171    fn get_modular_variance<Uint>(&self) -> f64
172    where
173        Uint: UnsignedInteger,
174    {
175        2_f64.powf(2. * (Uint::BITS as f64 + self.0.log2()))
176    }
177    fn get_modular_log_standard_dev<Uint>(&self) -> f64
178    where
179        Uint: UnsignedInteger,
180    {
181        Uint::BITS as f64 + self.0.log2()
182    }
183}
184
/// A distribution parameter that uses the variance as representation.
///
/// The wrapped `f64` is $\sigma^2$.
///
/// # Example
///
/// ```
/// use concrete_commons::dispersion::{DispersionParameter, Variance};
/// let params = Variance::from_variance(2_f64.powi(-50));
/// assert_eq!(params.get_standard_dev(), 2_f64.powf(-25.));
/// assert_eq!(params.get_log_standard_dev(), -25.);
/// assert_eq!(params.get_variance(), 2_f64.powf(-25.).powi(2));
/// assert_eq!(
///     params.get_modular_standard_dev::<u32>(),
///     2_f64.powf(32. - 25.)
/// );
/// assert_eq!(params.get_modular_log_standard_dev::<u32>(), 32. - 25.);
/// assert_eq!(
///     params.get_modular_variance::<u32>(),
///     2_f64.powf(32. - 25.).powi(2)
/// );
/// ```
#[cfg_attr(feature = "serde_serialize", derive(Serialize, Deserialize))]
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)]
pub struct Variance(pub f64);
207
208impl Variance {
209    pub fn from_variance(var: f64) -> Variance {
210        Variance(var)
211    }
212
213    pub fn from_modular_variance<Uint>(var: f64) -> Variance
214    where
215        Uint: UnsignedInteger,
216    {
217        Variance(var / 2_f64.powf(Uint::BITS as f64 * 2.))
218    }
219}
220
221impl DispersionParameter for Variance {
222    fn get_standard_dev(&self) -> f64 {
223        self.0.sqrt()
224    }
225    fn get_variance(&self) -> f64 {
226        self.0
227    }
228    fn get_log_standard_dev(&self) -> f64 {
229        self.0.sqrt().log2()
230    }
231    fn get_modular_standard_dev<Uint>(&self) -> f64
232    where
233        Uint: UnsignedInteger,
234    {
235        2_f64.powf(Uint::BITS as f64 + self.0.sqrt().log2())
236    }
237    fn get_modular_variance<Uint>(&self) -> f64
238    where
239        Uint: UnsignedInteger,
240    {
241        2_f64.powf(2. * (Uint::BITS as f64 + self.0.sqrt().log2()))
242    }
243    fn get_modular_log_standard_dev<Uint>(&self) -> f64
244    where
245        Uint: UnsignedInteger,
246    {
247        Uint::BITS as f64 + self.0.sqrt().log2()
248    }
249}