disty_cli/
kde.rs

/// Simple Gaussian kernel density estimator over a borrowed sample.
///
/// The estimator borrows its data instead of copying it and relies on the
/// slice being sorted in ascending order: [`KDE::pdf`] binary-searches it and
/// [`KDE::bounds`] reads the first and last element as minimum and maximum.
///
/// TODO make this even faster by porting the fast-kde paper cited at
/// <https://github.com/uwdata/fast-kde>
// `KDE` is the public name of the type, so the acronym-casing lint is silenced
// instead of renaming it to `Kde`.
#[allow(clippy::upper_case_acronyms)]
pub struct KDE<'a> {
    /// Sample values, sorted ascending (not checked).
    data: &'a [f64],
    /// Kernel bandwidth `h`, chosen once in [`KDE::new`].
    bandwidth: f64,
}

impl<'a> KDE<'a> {
    /// Create a KDE with automatic bandwidth selection (Silverman's rule of
    /// thumb, `h = 1.06 * σ * n^(-1/5)`, with `σ` the population standard
    /// deviation of `data`).
    ///
    /// `data` must already be sorted ascending; this is not checked.
    ///
    /// A sample without spread (a single point, or all values equal) makes the
    /// plain rule yield a zero bandwidth and therefore NaN densities. In that
    /// case `σ` is replaced by a fallback scale: `|data[0]|`, then `1.0`. An
    /// empty sample gets an arbitrary bandwidth of `1.0`.
    pub fn new(data: &'a [f64]) -> Self {
        let bandwidth = Self::select_bandwidth(data);
        KDE { data, bandwidth }
    }

    /// Pick the bandwidth for `data`, falling back to a usable positive value
    /// when the standard deviation cannot be used as the scale.
    fn select_bandwidth(data: &[f64]) -> f64 {
        let first = match data.first() {
            Some(&value) => value,
            // No samples: any positive bandwidth works, `pdf` never uses it.
            None => return 1.0,
        };

        let n = data.len() as f64;
        let mean = data.iter().sum::<f64>() / n;
        let variance = data.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
        let std_dev = variance.sqrt();

        // Silverman's rule of thumb: h ≈ 1.06 * scale * n^(-1/5).
        // Candidate scales are tried in order of preference; the first one
        // that produces a finite, strictly positive bandwidth wins.
        let shrink = n.powf(-0.2);
        [std_dev, first.abs(), 1.0]
            .iter()
            .map(|&scale| 1.06 * scale * shrink)
            .find(|h| h.is_finite() && *h > 0.0)
            .unwrap_or(1.0)
    }

    /// Probability density at `x`.
    ///
    /// Only samples within four bandwidths of `x` are summed, so the result is
    /// a slight underestimate of the exact Gaussian KDE and is exactly zero
    /// when no sample is that close. Returns `0.0` for an empty sample.
    pub fn pdf(&self, x: f64) -> f64 {
        // Without samples the normalisation below would be 0/0.
        if self.data.is_empty() {
            return 0.0;
        }

        let n = self.data.len() as f64;
        let h = self.bandwidth;

        // Optimization: only consider points within 4 bandwidths.
        // At that distance the kernel value is K(4) ≈ 1.3e-4, i.e. about
        // 0.03% of its peak, so the dropped terms are negligible.
        let cutoff = 4.0 * h;
        let lower = x - cutoff;
        let upper = x + cutoff;

        // Binary search for the window [lower, upper] (data is sorted).
        let start_idx = self.data.partition_point(|&xi| xi < lower);
        let end_idx = self.data.partition_point(|&xi| xi <= upper);

        let sum: f64 = self.data[start_idx..end_idx]
            .iter()
            .map(|&xi| gaussian_kernel((x - xi) / h))
            .sum();

        sum / (n * h)
    }

    /// Get bounds for plotting: the data range widened by 10% on each side.
    ///
    /// When the smallest value is non-negative the lower bound is clamped so
    /// the padding never pushes it below zero. An empty sample is treated as
    /// spanning `0.0..=1.0` before padding.
    pub fn bounds(&self) -> (f64, f64) {
        // Sorted data: first element is the minimum, last is the maximum.
        let min = self.data.first().copied().unwrap_or(0.0);
        let max = self.data.last().copied().unwrap_or(1.0);
        let padding = (max - min) * 0.1;

        // Clamp lower bound to 0 if all data is non-negative
        let lower = if min >= 0.0 {
            (min - padding).max(0.0)
        } else {
            min - padding
        };

        (lower, max + padding)
    }
}

/// Standard Gaussian kernel: K(u) = (1/√(2π)) * e^(-u²/2)
fn gaussian_kernel(u: f64) -> f64 {
    // Written out as a literal because `sqrt` is not usable in const contexts.
    const INV_SQRT_2PI: f64 = 0.3989422804014327;
    INV_SQRT_2PI * (-0.5 * u * u).exp()
}
71
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing against closed-form values.
    const EPS: f64 = 1e-10;

    // ---- gaussian_kernel ------------------------------------------------

    #[test]
    fn test_gaussian_kernel_at_zero() {
        // The kernel peaks at u = 0 with K(0) = 1/√(2π) ≈ 0.3989.
        let peak = gaussian_kernel(0.0);
        let error = (peak - 0.3989422804014327).abs();
        assert!(error < EPS);
    }

    #[test]
    fn test_gaussian_kernel_symmetric() {
        // K(u) depends on u only through u², so mirrored inputs match exactly.
        let offset = 1.5;
        let right = gaussian_kernel(offset);
        let left = gaussian_kernel(-offset);
        assert_eq!(right, left);
    }

    #[test]
    fn test_gaussian_kernel_decreases() {
        // Each step further from 0 must give a strictly smaller kernel value.
        let values: Vec<f64> = [0.0, 1.0, 2.0, 3.0]
            .iter()
            .map(|&u| gaussian_kernel(u))
            .collect();
        for pair in values.windows(2) {
            assert!(pair[0] > pair[1]);
        }
    }

    // ---- KDE::new ---------------------------------------------------------

    #[test]
    fn test_kde_new_simple() {
        let samples = [1.0, 2.0, 3.0, 4.0, 5.0];
        let estimator = KDE::new(&samples);

        // The estimator borrows the (already sorted) input unchanged.
        assert_eq!(estimator.data, &[1.0, 2.0, 3.0, 4.0, 5.0]);

        // A sample with spread must produce a positive bandwidth.
        assert!(estimator.bandwidth > 0.0);
    }

    #[test]
    fn test_kde_new_sorted_input() {
        let samples = [1.0, 2.0, 3.0, 4.0, 5.0];
        let estimator = KDE::new(&samples);

        // Sorted input stays in the same order.
        assert_eq!(estimator.data, &[1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_kde_bandwidth_silverman() {
        let samples = [1.0, 2.0, 3.0, 4.0, 5.0];

        // Recompute h = 1.06 * σ * n^(-1/5) independently of the estimator.
        let count = samples.len() as f64;
        let mean = samples.iter().sum::<f64>() / count;
        let squared_deviations: f64 = samples.iter().map(|v| (v - mean).powi(2)).sum();
        let sigma = (squared_deviations / count).sqrt();
        let expected = 1.06 * sigma * count.powf(-0.2);

        let estimator = KDE::new(&samples);
        let error = (estimator.bandwidth - expected).abs();
        assert!(error < EPS);
    }

    // ---- KDE::pdf ---------------------------------------------------------

    #[test]
    fn test_kde_pdf_at_data_point() {
        let samples = [1.0, 2.0, 3.0];
        let estimator = KDE::new(&samples);

        // Every sample contributes its own kernel peak, so density there is > 0.
        for &point in &samples {
            assert!(estimator.pdf(point) > 0.0);
        }
    }

    #[test]
    fn test_kde_pdf_peak_at_mean() {
        let samples = [1.8, 1.9, 2.0, 2.1, 2.2];
        let estimator = KDE::new(&samples);

        // Density at the centre of the cluster beats density well outside it.
        let at_mean = estimator.pdf(2.0);
        let outside = estimator.pdf(5.0);
        assert!(at_mean > outside);
    }

    #[test]
    fn test_kde_pdf_decreases_away_from_data() {
        // Clustered around 5.0 with more spread.
        let samples = [4.0, 4.5, 5.0, 5.5, 6.0];
        let estimator = KDE::new(&samples);

        let at_center = estimator.pdf(5.0);
        let far_away = estimator.pdf(15.0);

        // Density at the centre must exceed density far from all samples.
        assert!(at_center > far_away);
    }

    #[test]
    fn test_kde_pdf_bimodal() {
        // Two tight clusters with a gap between them.
        let samples = [1.0, 1.1, 1.2, 5.0, 5.1, 5.2];
        let estimator = KDE::new(&samples);

        let in_gap = estimator.pdf(3.0);

        // Both cluster centres must stand out above the gap.
        for &cluster_center in &[1.1, 5.1] {
            assert!(estimator.pdf(cluster_center) > in_gap);
        }
    }

    // ---- KDE::bounds ------------------------------------------------------

    #[test]
    fn test_kde_bounds_simple() {
        let samples = [1.0, 2.0, 3.0, 4.0, 5.0];
        let (low, high) = KDE::new(&samples).bounds();

        // The plotting range must contain the whole data range.
        assert!(low <= 1.0);
        assert!(high >= 5.0);

        // The upper side is padded by 10% of the range (1% slack for FP error).
        let span = 5.0 - 1.0;
        let min_padding = span * 0.1 * 0.99;
        assert!((high - 5.0) >= min_padding);
    }

    #[test]
    fn test_kde_bounds_non_negative_data() {
        let samples = [1.0, 2.0, 3.0];
        let (low, _) = KDE::new(&samples).bounds();

        // Padding must not push the lower bound below zero for non-negative data.
        assert!(low >= 0.0);
    }

    #[test]
    fn test_kde_bounds_negative_data() {
        let samples = [-5.0, -2.0, 1.0];
        let (low, _) = KDE::new(&samples).bounds();

        // With negative data the padding extends below the minimum of -5.0.
        assert!(low < -5.0);
    }
}
216}