// hyperopt/kernel/continuous/gaussian.rs

1use std::{f64::consts::TAU, fmt::Debug};
2
3use fastrand::Rng;
4use num_traits::FromPrimitive;
5
6use crate::{
7    constants::{ConstFrac1SqrtTau, ConstOneHalf},
8    kernel::{Density, Kernel, Sample},
9    traits::{
10        loopback::{SelfAdd, SelfExp, SelfMul, SelfNeg, SelfSub},
11        shortcuts::Multiplicative,
12    },
13};
14
/// A [Gaussian (normal)][1] kernel, described by its centre and its spread.
///
/// Both parameters share the numeric type `T`: the `location` (mean) and the
/// `std` (standard deviation).
///
/// [1]: https://en.wikipedia.org/wiki/Normal_distribution
#[derive(Debug, Clone, Copy)]
pub struct Gaussian<T> {
    /// Mean of the kernel: the point at which the density is largest.
    location: T,

    /// Standard deviation of the kernel; `Kernel::new` asserts `std > 0`.
    std: T,
}
23
24impl<T> Density for Gaussian<T>
25where
26    T: Copy + ConstFrac1SqrtTau + SelfSub + Multiplicative + ConstOneHalf + SelfExp + SelfNeg,
27{
28    type Param = T;
29    type Output = T;
30
31    fn density(&self, at: Self::Param) -> Self::Output {
32        let normalized = (at - self.location) / self.std;
33        T::FRAC_1_SQRT_TAU * (-T::ONE_HALF * normalized * normalized).exp() / self.std
34    }
35}
36
37impl<T> Sample for Gaussian<T>
38where
39    T: Copy + SelfAdd + SelfMul + FromPrimitive,
40{
41    type Param = T;
42
43    /// [Generate a sample][1] from the Gaussian kernel.
44    ///
45    /// [1]: https://en.wikipedia.org/wiki/Box–Muller_transform
46    fn sample(&self, rng: &mut Rng) -> Self::Param {
47        let u1 = rng.f64();
48        let u2 = rng.f64();
49        let normalized = T::from_f64((-2.0 * u1.ln()).sqrt() * (TAU * u2).cos()).unwrap();
50        self.location + self.std * normalized
51    }
52}
53
54impl<T> Kernel for Gaussian<T>
55where
56    Self: Density<Param = T, Output = T> + Sample<Param = T>,
57    T: PartialOrd + num_traits::Zero,
58{
59    type Param = T;
60
61    fn new(location: T, std: T) -> Self {
62        assert!(std > T::zero());
63        Self { location, std }
64    }
65}
66
67impl<T> Default for Gaussian<T>
68where
69    T: num_traits::Zero + num_traits::One,
70{
71    /// Zero-centered standard normal distribution.
72    fn default() -> Self {
73        Self {
74            location: T::zero(),
75            std: T::one(),
76        }
77    }
78}
79
#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;

    use super::*;

    /// Checks the standard normal density at its mode and one standard deviation
    /// away on either side.
    #[test]
    fn density_ok() {
        // φ(0) = 1/√(2π) and φ(±1) = e^(-1/2)/√(2π).
        let peak = 0.398_942_280_401_432_7;
        let one_sigma = 0.241_970_724_519_143_37;

        let standard_normal = Gaussian::default();
        assert_abs_diff_eq!(standard_normal.density(0.0), peak);
        for &x in &[1.0, -1.0] {
            assert_abs_diff_eq!(standard_normal.density(x), one_sigma);
        }
    }
}