concision_init/distr/trunc.rs
1/*
2 Appellation: trunc <distr>
3 Contrib: FL03 <jo3mccain@icloud.com>
4*/
5use num::traits::Float;
6use rand::{Rng, RngCore};
7use rand_distr::{Distribution, Normal, StandardNormal};
8
9/// The [`TruncatedNormal`] distribution is similar to the [`StandardNormal`] distribution,
10/// differing in that is computes a boundary equal to two standard deviations from the mean.
11/// More formally, the boundary is defined as:
12///
13/// ```math
14/// \text{boundary} = \mu + 2\sigma
15/// ```
16#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
17pub struct TruncatedNormal<T>
18where
19 StandardNormal: Distribution<T>,
20{
21 pub(crate) mean: T,
22 pub(crate) std: T,
23}
24
25impl<T> TruncatedNormal<T>
26where
27 T: Copy,
28 StandardNormal: Distribution<T>,
29{
30 /// create a new [`TruncatedNormal`] distribution with the given mean and standard
31 /// deviation; both of which are type `T`.
32 pub const fn new(mean: T, std: T) -> crate::Result<Self> {
33 Ok(Self { mean, std })
34 }
35 /// returns a copy of the mean for the distribution
36 pub const fn mean(&self) -> T {
37 self.mean
38 }
39 /// returns a copy of the standard deviation for the distribution
40 pub const fn std_dev(&self) -> T {
41 self.std
42 }
43 /// compute the boundary of the truncated normal distribution
44 /// which is two standard deviations from the mean:
45 /// $$
46 /// \text{boundary} = \mu + 2\sigma
47 /// $$
48 pub fn boundary(&self) -> T
49 where
50 T: Float,
51 {
52 self.mean() + self.std_dev() * T::from(2).unwrap()
53 }
54 /// returns a new [`Normal`] distribution instance created from the current mean and
55 /// standard deviation.
56 pub fn distr(&self) -> Normal<T>
57 where
58 T: Float,
59 {
60 Normal::new(self.mean(), self.std_dev()).unwrap()
61 }
62 /// compute the score of the distribution at point `x`. The score is calculated by
63 /// subtracing a scaled standard deviation from the mean:
64 /// $$
65 /// \text{score}(x) = \mu - \sigma \cdot x
66 /// $$
67 ///
68 /// where $\mu$ is the mean and $\sigma$ is the standard deviation.
69 pub fn score(&self, x: T) -> T
70 where
71 T: Float,
72 {
73 self.mean() - self.std_dev() * x
74 }
75}
76
77impl<T> Distribution<T> for TruncatedNormal<T>
78where
79 T: Float,
80 StandardNormal: Distribution<T>,
81{
82 fn sample<R>(&self, rng: &mut R) -> T
83 where
84 R: RngCore + ?Sized,
85 {
86 let bnd = self.boundary();
87 let mut x = self.score(rng.sample(StandardNormal));
88 // if x is outside of the boundary, re-sample
89 while x < -bnd || x > bnd {
90 x = self.score(rng.sample(StandardNormal));
91 }
92 x
93 }
94}
95
96impl<T> From<Normal<T>> for TruncatedNormal<T>
97where
98 T: Float,
99 StandardNormal: Distribution<T>,
100{
101 fn from(normal: Normal<T>) -> Self {
102 Self {
103 mean: normal.mean(),
104 std: normal.std_dev(),
105 }
106 }
107}