concision_core/init/distr/
trunc.rs1use num::traits::Float;
6use rand::Rng;
7use rand_distr::{Distribution, Normal, NormalError, StandardNormal};
8
/// A normal distribution, parameterised by its mean and standard deviation, whose
/// samples are truncated by rejection sampling (see the `Distribution` impl).
///
/// No trait bounds are placed on the type itself; the bounds required for
/// construction and sampling live on the `impl` blocks that need them.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct TruncatedNormal<F> {
    /// Mean of the underlying normal distribution.
    mean: F,
    /// Standard deviation of the underlying normal distribution.
    std: F,
}
19
20impl<F> TruncatedNormal<F>
21where
22 F: Float,
23 StandardNormal: Distribution<F>,
24{
25 pub fn new(mean: F, std: F) -> Result<Self, NormalError> {
27 Ok(Self { mean, std })
28 }
29
30 pub(crate) fn boundary(&self) -> F {
31 self.mean() + self.std_dev() * F::from(2).unwrap()
32 }
33
34 pub(crate) fn score(&self, x: F) -> F {
35 self.mean() - self.std_dev() * x
36 }
37
38 pub fn distr(&self) -> Normal<F> {
39 Normal::new(self.mean(), self.std_dev()).unwrap()
40 }
41
42 pub fn mean(&self) -> F {
43 self.mean
44 }
45
46 pub fn std_dev(&self) -> F {
47 self.std
48 }
49}
50
51impl<F> Distribution<F> for TruncatedNormal<F>
52where
53 F: Float,
54 StandardNormal: Distribution<F>,
55{
56 fn sample<R>(&self, rng: &mut R) -> F
57 where
58 R: Rng + ?Sized,
59 {
60 let bnd = self.boundary();
61 let mut x = self.score(rng.sample(StandardNormal));
62 while x < -bnd || x > bnd {
64 x = self.score(rng.sample(StandardNormal));
65 }
66 x
67 }
68}
69
70impl<F> From<Normal<F>> for TruncatedNormal<F>
71where
72 F: Float,
73 StandardNormal: Distribution<F>,
74{
75 fn from(normal: Normal<F>) -> Self {
76 Self {
77 mean: normal.mean(),
78 std: normal.std_dev(),
79 }
80 }
81}