concision_init/distr/
trunc.rs1use num::traits::Float;
6use rand::{Rng, RngCore};
7use rand_distr::{Distribution, Normal, StandardNormal};
8
/// Parameters of a truncated normal distribution: the location `mean` and the
/// scale `std`.
///
/// Sampling is provided by the `Distribution` implementation. The trait bounds
/// that sampling needs are placed on the `impl` blocks rather than on this
/// definition, so the type can be named in generic code without repeating them.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct TruncatedNormal<T> {
    mean: T,
    std: T,
}
19
20impl<T> TruncatedNormal<T>
21where
22 T: Float,
23 StandardNormal: Distribution<T>,
24{
25 pub const fn new(mean: T, std: T) -> crate::Result<Self> {
28 Ok(Self { mean, std })
29 }
30 pub(crate) fn boundary(&self) -> T {
36 self.mean() + self.std_dev() * T::from(2).unwrap()
37 }
38
39 pub(crate) fn score(&self, x: T) -> T {
40 self.mean() - self.std_dev() * x
41 }
42
43 pub fn distr(&self) -> Normal<T> {
44 Normal::new(self.mean(), self.std_dev()).unwrap()
45 }
46
47 pub fn mean(&self) -> T {
48 self.mean
49 }
50
51 pub fn std_dev(&self) -> T {
52 self.std
53 }
54}
55
56impl<T> Distribution<T> for TruncatedNormal<T>
57where
58 T: Float,
59 StandardNormal: Distribution<T>,
60{
61 fn sample<R>(&self, rng: &mut R) -> T
62 where
63 R: RngCore + ?Sized,
64 {
65 let bnd = self.boundary();
66 let mut x = self.score(rng.sample(StandardNormal));
67 while x < -bnd || x > bnd {
69 x = self.score(rng.sample(StandardNormal));
70 }
71 x
72 }
73}
74
75impl<T> From<Normal<T>> for TruncatedNormal<T>
76where
77 T: Float,
78 StandardNormal: Distribution<T>,
79{
80 fn from(normal: Normal<T>) -> Self {
81 Self {
82 mean: normal.mean(),
83 std: normal.std_dev(),
84 }
85 }
86}