concision_init/distr/
trunc.rs

/*
    Appellation: trunc <distr>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
5use num::traits::Float;
6use rand::{Rng, RngCore};
7use rand_distr::{Distribution, Normal, StandardNormal};
8
/// A truncated normal distribution.
///
/// It is similar to a [normal](rand_distr::Normal) [distribution](rand_distr::Distribution)
/// described by a mean and a standard deviation; however, any generated value lying more
/// than two standard deviations from the mean is discarded and re-generated.
///
/// The type itself carries no trait bounds: the requirements for sampling are declared on
/// the `impl` blocks that actually need them, so the derived traits stay usable for any `T`
/// that supports them.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct TruncatedNormal<T> {
    /// the mean of the distribution
    mean: T,
    /// the standard deviation of the distribution
    std: T,
}
19
20impl<T> TruncatedNormal<T>
21where
22    T: Float,
23    StandardNormal: Distribution<T>,
24{
25    /// create a new [`TruncatedNormal`] distribution with the given mean and standard
26    /// deviation; both of which are type `T`.
27    pub const fn new(mean: T, std: T) -> crate::Result<Self> {
28        Ok(Self { mean, std })
29    }
30    /// compute the boundary of the truncated normal distribution
31    /// which is two standard deviations from the mean:
32    /// $$
33    /// \text{boundary} = \mu + 2\sigma
34    /// $$
35    pub(crate) fn boundary(&self) -> T {
36        self.mean() + self.std_dev() * T::from(2).unwrap()
37    }
38
39    pub(crate) fn score(&self, x: T) -> T {
40        self.mean() - self.std_dev() * x
41    }
42
43    pub fn distr(&self) -> Normal<T> {
44        Normal::new(self.mean(), self.std_dev()).unwrap()
45    }
46
47    pub fn mean(&self) -> T {
48        self.mean
49    }
50
51    pub fn std_dev(&self) -> T {
52        self.std
53    }
54}
55
56impl<T> Distribution<T> for TruncatedNormal<T>
57where
58    T: Float,
59    StandardNormal: Distribution<T>,
60{
61    fn sample<R>(&self, rng: &mut R) -> T
62    where
63        R: RngCore + ?Sized,
64    {
65        let bnd = self.boundary();
66        let mut x = self.score(rng.sample(StandardNormal));
67        // if x is outside of the boundary, re-sample
68        while x < -bnd || x > bnd {
69            x = self.score(rng.sample(StandardNormal));
70        }
71        x
72    }
73}
74
75impl<T> From<Normal<T>> for TruncatedNormal<T>
76where
77    T: Float,
78    StandardNormal: Distribution<T>,
79{
80    fn from(normal: Normal<T>) -> Self {
81        Self {
82            mean: normal.mean(),
83            std: normal.std_dev(),
84        }
85    }
86}