hyperopt/kernel/continuous/
epanechnikov.rs1use std::fmt::Debug;
2
3use fastrand::Rng;
4
5use crate::{
6 constants::{ConstSqrt5, ConstThreeQuarters},
7 kernel::{Density, Kernel, Sample},
8 traits::{
9 loopback::{SelfAdd, SelfDiv, SelfMul, SelfNeg, SelfSub},
10 shortcuts::Multiplicative,
11 },
12};
13
/// Epanechnikov (parabolic) kernel described by its centre and standard deviation.
///
/// The kernel is rescaled so that `std` is the standard deviation of the distribution
/// rather than the half-width of its support: the density is non-zero only on
/// `location ± std · √5`.
#[derive(Clone, Copy, Debug)]
pub struct Epanechnikov<T> {
    /// Centre of the kernel.
    location: T,
    /// Standard deviation; the support half-width is `std · √5`.
    std: T,
}
22
23impl<T> Density for Epanechnikov<T>
24where
25 T: SelfSub
26 + Multiplicative
27 + Copy
28 + PartialOrd
29 + SelfNeg
30 + SelfDiv
31 + num_traits::One
32 + num_traits::Zero
33 + ConstSqrt5
34 + ConstThreeQuarters,
35{
36 type Param = T;
37 type Output = T;
38
39 fn density(&self, at: Self::Param) -> Self::Output {
40 let normalized = (at - self.location) / self.std / T::SQRT_5;
42 if (-T::one()..=T::one()).contains(&normalized) {
43 T::THREE_QUARTERS / T::SQRT_5 * (T::one() - normalized * normalized) / self.std
45 } else {
46 T::zero()
48 }
49 }
50}
51
52impl<T> Sample for Epanechnikov<T>
53where
54 T: Copy + SelfAdd + SelfMul + TryFrom<f64>,
55 <T as TryFrom<f64>>::Error: Debug,
56{
57 type Param = T;
58
59 fn sample(&self, rng: &mut Rng) -> Self::Param {
63 let (x1, x2) = min_2(rng.f64(), rng.f64(), rng.f64());
65
66 let abs_normalized = if rng.bool() { x1 } else { x2 };
68
69 let normalized = if rng.bool() {
71 abs_normalized
72 } else {
73 -abs_normalized
74 };
75
76 self.location + self.std * T::try_from(normalized * f64::SQRT_5).unwrap()
78 }
79}
80
81impl<T> Kernel for Epanechnikov<T>
82where
83 Self: Density<Param = T, Output = T> + Sample<Param = T>,
84 T: PartialOrd + num_traits::Zero,
85{
86 type Param = T;
87
88 fn new(location: T, std: T) -> Self {
89 assert!(std > T::zero());
90 Self { location, std }
91 }
92}
93
94impl<T> Default for Epanechnikov<T>
95where
96 T: num_traits::Zero + num_traits::One,
97{
98 fn default() -> Self {
99 Self {
100 location: T::zero(),
101 std: T::one(),
102 }
103 }
104}
105
/// Returns the two smallest of three values, dropping the largest.
///
/// The pair is *not* sorted. The first element is the smaller of `x1` and `x2`. The second
/// is the smaller of `x3` and whichever of `x1`/`x2` was not chosen. Ties, and comparisons
/// that are unordered (e.g. involving NaN), keep the earlier argument.
fn min_2<T: PartialOrd>(x1: T, x2: T, x3: T) -> (T, T) {
    let (low, high) = if x2 < x1 { (x2, x1) } else { (x1, x2) };
    let runner_up = if x3 < high { x3 } else { high };
    (low, runner_up)
}
116
#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;

    use super::*;

    /// Peak density of the unit-variance kernel: `3 / (4·√5)`.
    const UNIT_PEAK: f64 = 0.335_410_196_624_968_46;

    #[test]
    fn density_inside_ok() {
        let kernel = Epanechnikov::<f64>::default();
        assert_abs_diff_eq!(kernel.density(0.0), UNIT_PEAK);
        // With unit standard deviation the support ends at ±√5.
        assert_abs_diff_eq!(kernel.density(f64::SQRT_5), 0.0);
        assert_abs_diff_eq!(kernel.density(-f64::SQRT_5), 0.0);
    }

    #[test]
    #[allow(clippy::float_cmp)] // outside the support the density is exactly zero
    fn density_outside_ok() {
        let kernel = Epanechnikov::<f64>::default();
        assert_eq!(kernel.density(-10.0), 0.0);
        assert_eq!(kernel.density(10.0), 0.0);
    }

    #[test]
    fn density_has_unit_mass_and_unit_variance() {
        let kernel = Epanechnikov::<f64>::default();
        // Midpoint rule over [-3, 3], which covers the whole support [-√5, √5].
        let step = 1e-3;
        let (mut mass, mut second_moment) = (0.0, 0.0);
        for i in 0..6_000_u32 {
            let x = -3.0 + (f64::from(i) + 0.5) * step;
            let weight = kernel.density(x) * step;
            mass += weight;
            second_moment += x * x * weight;
        }
        assert_abs_diff_eq!(mass, 1.0, epsilon = 1e-4);
        assert_abs_diff_eq!(second_moment, 1.0, epsilon = 1e-4);
    }

    #[test]
    fn new_shifts_and_scales() {
        let kernel = Epanechnikov::<f64>::new(1.0, 2.0);
        // Doubling the standard deviation halves the peak height.
        assert_abs_diff_eq!(kernel.density(1.0), UNIT_PEAK / 2.0);
        assert_abs_diff_eq!(kernel.density(1.0 + 2.0 * f64::SQRT_5), 0.0, epsilon = 1e-12);
        assert_abs_diff_eq!(kernel.density(-10.0), 0.0);
    }

    #[test]
    #[should_panic]
    fn new_rejects_non_positive_std() {
        let _ = Epanechnikov::<f64>::new(0.0, 0.0);
    }

    #[test]
    fn sample_matches_support_and_moments() {
        let kernel = Epanechnikov::<f64>::default();
        let mut rng = Rng::with_seed(42);
        let n = 20_000_u32;
        let samples: Vec<f64> = (0..n).map(|_| kernel.sample(&mut rng)).collect();

        assert!(samples.iter().all(|x| x.abs() <= f64::SQRT_5));

        let mean = samples.iter().sum::<f64>() / f64::from(n);
        let variance = samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / f64::from(n);
        // Tolerances are several standard errors wide for n = 20 000.
        assert_abs_diff_eq!(mean, 0.0, epsilon = 0.05);
        assert_abs_diff_eq!(variance, 1.0, epsilon = 0.05);
    }

    #[test]
    fn min_2_ok() {
        assert_eq!(min_2(1, 2, 3), (1, 2));
        assert_eq!(min_2(1, 3, 2), (1, 2));
        assert_eq!(min_2(2, 1, 3), (1, 2));
        assert_eq!(min_2(2, 3, 1), (2, 1));
        assert_eq!(min_2(3, 1, 2), (1, 2));
        assert_eq!(min_2(3, 2, 1), (2, 1));
    }
}