/// Gaussian kernel density estimator over a borrowed sample.
///
/// The bandwidth is chosen once at construction with Silverman's rule of thumb,
/// `1.06 * σ * n^(-1/5)`, where `σ` is the population standard deviation of the sample.
///
/// Samples sorted in ascending order take a fast path: [`KDE::pdf`] binary-searches
/// the window of samples near the query point, and [`KDE::bounds`] reads the extremes
/// from the ends of the slice. Unsorted samples are also accepted and handled with
/// linear scans, so results never depend on input order.
// "KDE" is the conventional name for this estimator; renaming it would break callers.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone)]
pub struct KDE<'a> {
    /// Sample points the density is estimated from.
    data: &'a [f64],
    /// Kernel bandwidth `h`; always finite and strictly positive after `new`.
    bandwidth: f64,
    /// Whether `data` is in ascending order (enables the binary-search fast paths).
    sorted: bool,
}

impl<'a> KDE<'a> {
    /// Bandwidth used when Silverman's rule yields zero or a non-finite value
    /// (empty input, a single sample, or zero variance).
    const FALLBACK_BANDWIDTH: f64 = 1.0;

    /// Creates an estimator over `data`, choosing the bandwidth with Silverman's rule.
    ///
    /// `data` may be in any order; ascending input lets [`KDE::pdf`] skip distant samples
    /// via binary search instead of scanning the whole slice.
    ///
    /// If the rule produces a bandwidth that is not strictly positive and finite
    /// (e.g. empty input or all samples equal), a bandwidth of `1.0` is used instead so
    /// that densities remain finite.
    #[must_use]
    pub fn new(data: &'a [f64]) -> Self {
        let n = data.len() as f64;

        let mean = data.iter().sum::<f64>() / n;
        let variance = data.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
        let std_dev = variance.sqrt();

        let silverman = 1.06 * std_dev * n.powf(-0.2);
        // A zero or NaN bandwidth would make every kernel evaluation in `pdf` NaN (0/0).
        let bandwidth = if silverman.is_finite() && silverman > 0.0 {
            silverman
        } else {
            Self::FALLBACK_BANDWIDTH
        };

        let sorted = data.windows(2).all(|pair| pair[0] <= pair[1]);

        KDE {
            data,
            bandwidth,
            sorted,
        }
    }

    /// Evaluates the estimated probability density at `x`.
    ///
    /// Samples farther than four bandwidths from `x` are ignored: each would contribute
    /// less than `exp(-8) ≈ 3.4e-4` of the kernel's peak value. Returns `0.0` when the
    /// sample is empty.
    #[must_use]
    pub fn pdf(&self, x: f64) -> f64 {
        if self.data.is_empty() {
            return 0.0;
        }

        let n = self.data.len() as f64;
        let h = self.bandwidth;

        let cutoff = 4.0 * h;
        let lower = x - cutoff;
        let upper = x + cutoff;

        let kernel_at = |xi: f64| gaussian_kernel((x - xi) / h);

        let sum: f64 = if self.sorted {
            // Binary-search the samples lying in [lower, upper] instead of visiting all.
            let start_idx = self.data.partition_point(|&xi| xi < lower);
            let end_idx = self.data.partition_point(|&xi| xi <= upper);
            self.data[start_idx..end_idx]
                .iter()
                .map(|&xi| kernel_at(xi))
                .sum()
        } else {
            // `partition_point` is meaningless on unsorted data, so filter linearly.
            self.data
                .iter()
                .filter(|&&xi| (lower..=upper).contains(&xi))
                .map(|&xi| kernel_at(xi))
                .sum()
        };

        sum / (n * h)
    }

    /// Returns a plotting range `(lower, upper)` spanning the sample plus 10% of its
    /// width on each side.
    ///
    /// When the smallest sample is non-negative, `lower` is clamped to `0.0`. An empty
    /// sample is treated as spanning `[0.0, 1.0]`, which yields `(0.0, 1.1)`.
    #[must_use]
    pub fn bounds(&self) -> (f64, f64) {
        let (min, max) = if self.sorted {
            // Also covers the empty slice, which counts as sorted.
            (
                self.data.first().copied().unwrap_or(0.0),
                self.data.last().copied().unwrap_or(1.0),
            )
        } else {
            // Unsorted implies at least two samples, so the fold sees real values.
            self.data
                .iter()
                .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
                    (lo.min(v), hi.max(v))
                })
        };
        let padding = (max - min) * 0.1;

        let lower = if min >= 0.0 {
            // Keep non-negative data (e.g. durations, counts) from gaining a negative range.
            (min - padding).max(0.0)
        } else {
            min - padding
        };

        (lower, max + padding)
    }
}

/// Standard normal density `exp(-u²/2) / √(2π)`, used as the smoothing kernel.
fn gaussian_kernel(u: f64) -> f64 {
    // 1 / √(2π): normalising constant of the standard normal density.
    const INV_SQRT_2PI: f64 = 0.3989422804014327;
    INV_SQRT_2PI * (-0.5 * u * u).exp()
}
71
#[cfg(test)]
mod tests {
    use super::*;

    /// Peak of the standard normal density, 1/√(2π).
    const NORMAL_PEAK: f64 = 0.3989422804014327;

    /// The evenly spaced sample 1, 2, 3, 4, 5 shared by several tests.
    fn one_to_five() -> Vec<f64> {
        vec![1.0, 2.0, 3.0, 4.0, 5.0]
    }

    #[test]
    fn test_gaussian_kernel_at_zero() {
        let error = (gaussian_kernel(0.0) - NORMAL_PEAK).abs();
        assert!(error < 1e-10);
    }

    #[test]
    fn test_gaussian_kernel_symmetric() {
        let offset = 1.5;
        let (left, right) = (gaussian_kernel(-offset), gaussian_kernel(offset));
        assert_eq!(right, left);
    }

    #[test]
    fn test_gaussian_kernel_decreases() {
        let samples: Vec<f64> = [0.0, 1.0, 2.0, 3.0]
            .iter()
            .map(|&u| gaussian_kernel(u))
            .collect();
        assert!(samples.windows(2).all(|pair| pair[0] > pair[1]));
    }

    #[test]
    fn test_kde_new_simple() {
        let sample = one_to_five();
        let estimator = KDE::new(&sample);

        assert_eq!(estimator.data, &[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(estimator.bandwidth > 0.0);
    }

    #[test]
    fn test_kde_new_sorted_input() {
        let sample = one_to_five();
        let estimator = KDE::new(&sample);

        assert_eq!(estimator.data, &[1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_kde_pdf_at_data_point() {
        let sample = vec![1.0, 2.0, 3.0];
        let estimator = KDE::new(&sample);

        for &point in &[1.0, 2.0, 3.0] {
            assert!(estimator.pdf(point) > 0.0);
        }
    }

    #[test]
    fn test_kde_pdf_peak_at_mean() {
        let sample = vec![1.8, 1.9, 2.0, 2.1, 2.2];
        let estimator = KDE::new(&sample);

        let (near, distant) = (estimator.pdf(2.0), estimator.pdf(5.0));
        assert!(near > distant);
    }

    #[test]
    fn test_kde_pdf_decreases_away_from_data() {
        let sample = vec![4.0, 4.5, 5.0, 5.5, 6.0];
        let estimator = KDE::new(&sample);

        let (middle, remote) = (estimator.pdf(5.0), estimator.pdf(15.0));
        assert!(middle > remote);
    }

    #[test]
    fn test_kde_bounds_simple() {
        let sample = one_to_five();
        let (lo, hi) = KDE::new(&sample).bounds();

        assert!(lo <= 1.0);
        assert!(hi >= 5.0);

        let expected_padding = (5.0 - 1.0) * 0.1;
        assert!((hi - 5.0) >= expected_padding * 0.99);
    }

    #[test]
    fn test_kde_bounds_non_negative_data() {
        let sample = vec![1.0, 2.0, 3.0];
        let (lo, _) = KDE::new(&sample).bounds();

        assert!(lo >= 0.0);
    }

    #[test]
    fn test_kde_bounds_negative_data() {
        let sample = vec![-5.0, -2.0, 1.0];
        let (lo, _) = KDE::new(&sample).bounds();

        assert!(lo < -5.0);
    }

    #[test]
    fn test_kde_bandwidth_silverman() {
        let sample = one_to_five();
        let count = sample.len() as f64;
        let centre = sample.iter().sum::<f64>() / count;
        let spread = (sample.iter().map(|v| (v - centre).powi(2)).sum::<f64>() / count).sqrt();
        let expected = 1.06 * spread * count.powf(-0.2);

        let estimator = KDE::new(&sample);
        assert!((estimator.bandwidth - expected).abs() < 1e-10);
    }

    #[test]
    fn test_kde_pdf_bimodal() {
        let sample = vec![1.0, 1.1, 1.2, 5.0, 5.1, 5.2];
        let estimator = KDE::new(&sample);

        let valley = estimator.pdf(3.0);
        assert!(estimator.pdf(1.1) > valley);
        assert!(estimator.pdf(5.1) > valley);
    }
}