concision_core/init/distr/
trunc.rs

1/*
2    Appellation: trunc <distr>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
5use num::traits::Float;
6use rand::Rng;
7use rand_distr::{Distribution, Normal, NormalError, StandardNormal};
8
/// A normal distribution truncated to within two standard deviations of its mean.
///
/// It behaves like a [normal](rand_distr::Normal) [distribution](rand_distr::Distribution),
/// except that any generated value more than two standard deviations away from the mean is
/// discarded and re-generated, so accepted samples lie in `[mean - 2σ, mean + 2σ]`.
///
/// The type itself carries no trait bounds; they are required only by the `impl` blocks
/// that actually need them.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct TruncatedNormal<F> {
    /// Centre of the distribution.
    mean: F,
    /// Standard deviation of the underlying (untruncated) normal distribution.
    std: F,
}
19
20impl<F> TruncatedNormal<F>
21where
22    F: Float,
23    StandardNormal: Distribution<F>,
24{
25    /// Create a new truncated normal distribution with a given mean and standard deviation
26    pub fn new(mean: F, std: F) -> Result<Self, NormalError> {
27        Ok(Self { mean, std })
28    }
29
30    pub(crate) fn boundary(&self) -> F {
31        self.mean() + self.std_dev() * F::from(2).unwrap()
32    }
33
34    pub(crate) fn score(&self, x: F) -> F {
35        self.mean() - self.std_dev() * x
36    }
37
38    pub fn distr(&self) -> Normal<F> {
39        Normal::new(self.mean(), self.std_dev()).unwrap()
40    }
41
42    pub fn mean(&self) -> F {
43        self.mean
44    }
45
46    pub fn std_dev(&self) -> F {
47        self.std
48    }
49}
50
51impl<F> Distribution<F> for TruncatedNormal<F>
52where
53    F: Float,
54    StandardNormal: Distribution<F>,
55{
56    fn sample<R>(&self, rng: &mut R) -> F
57    where
58        R: Rng + ?Sized,
59    {
60        let bnd = self.boundary();
61        let mut x = self.score(rng.sample(StandardNormal));
62        // if x is outside of the boundary, re-sample
63        while x < -bnd || x > bnd {
64            x = self.score(rng.sample(StandardNormal));
65        }
66        x
67    }
68}
69
70impl<F> From<Normal<F>> for TruncatedNormal<F>
71where
72    F: Float,
73    StandardNormal: Distribution<F>,
74{
75    fn from(normal: Normal<F>) -> Self {
76        Self {
77            mean: normal.mean(),
78            std: normal.std_dev(),
79        }
80    }
81}