concision_init/distr/
trunc.rs

/*
    Appellation: trunc <distr>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
5use num::traits::Float;
6use rand::{Rng, RngCore};
7use rand_distr::{Distribution, Normal, StandardNormal};
8
/// The [`TruncatedNormal`] distribution is a [`Normal`] distribution, parameterised by a
/// mean and a standard deviation, whose samples are restricted by a boundary located two
/// standard deviations from the mean. More formally, the boundary is defined as:
///
/// ```math
/// \text{boundary} = \mu + 2\sigma
/// ```
///
/// Trait bounds live on the `impl` blocks that need them rather than on the type itself.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct TruncatedNormal<T> {
    /// the mean ($\mu$) of the underlying normal distribution
    pub(crate) mean: T,
    /// the standard deviation ($\sigma$) of the underlying normal distribution
    pub(crate) std: T,
}
24
25impl<T> TruncatedNormal<T>
26where
27    T: Copy,
28    StandardNormal: Distribution<T>,
29{
30    /// create a new [`TruncatedNormal`] distribution with the given mean and standard
31    /// deviation; both of which are type `T`.
32    pub const fn new(mean: T, std: T) -> crate::Result<Self> {
33        Ok(Self { mean, std })
34    }
35    /// returns a copy of the mean for the distribution
36    pub const fn mean(&self) -> T {
37        self.mean
38    }
39    /// returns a copy of the standard deviation for the distribution
40    pub const fn std_dev(&self) -> T {
41        self.std
42    }
43    /// compute the boundary of the truncated normal distribution
44    /// which is two standard deviations from the mean:
45    /// $$
46    /// \text{boundary} = \mu + 2\sigma
47    /// $$
48    pub fn boundary(&self) -> T
49    where
50        T: Float,
51    {
52        self.mean() + self.std_dev() * T::from(2).unwrap()
53    }
54    /// returns a new [`Normal`] distribution instance created from the current mean and
55    /// standard deviation.
56    pub fn distr(&self) -> Normal<T>
57    where
58        T: Float,
59    {
60        Normal::new(self.mean(), self.std_dev()).unwrap()
61    }
62    /// compute the score of the distribution at point `x`. The score is calculated by
63    /// subtracing a scaled standard deviation from the mean:
64    /// $$
65    /// \text{score}(x) = \mu - \sigma \cdot x
66    /// $$
67    ///
68    /// where $\mu$ is the mean and $\sigma$ is the standard deviation.
69    pub fn score(&self, x: T) -> T
70    where
71        T: Float,
72    {
73        self.mean() - self.std_dev() * x
74    }
75}
76
77impl<T> Distribution<T> for TruncatedNormal<T>
78where
79    T: Float,
80    StandardNormal: Distribution<T>,
81{
82    fn sample<R>(&self, rng: &mut R) -> T
83    where
84        R: RngCore + ?Sized,
85    {
86        let bnd = self.boundary();
87        let mut x = self.score(rng.sample(StandardNormal));
88        // if x is outside of the boundary, re-sample
89        while x < -bnd || x > bnd {
90            x = self.score(rng.sample(StandardNormal));
91        }
92        x
93    }
94}
95
96impl<T> From<Normal<T>> for TruncatedNormal<T>
97where
98    T: Float,
99    StandardNormal: Distribution<T>,
100{
101    fn from(normal: Normal<T>) -> Self {
102        Self {
103            mean: normal.mean(),
104            std: normal.std_dev(),
105        }
106    }
107}