Skip to main content

flow_utils/kde/
kde2d.rs

1//! 2D Kernel Density Estimation
2//!
3//! Provides 2D KDE for scatter plots and density-based gating.
4
5use crate::common::{gaussian_kernel, interquartile_range, standard_deviation};
6use crate::kde::{KdeError, KdeResult};
7use ndarray::Array2;
8use realfft::RealFftPlanner;
9use realfft::num_complex::Complex;
10
11/// 2D Kernel Density Estimation result
12#[derive(Debug)]
13pub struct KernelDensity2D {
14    /// X grid points
15    pub x: Vec<f64>,
16    /// Y grid points
17    pub y: Vec<f64>,
18    /// Density values (2D grid: x.len() × y.len())
19    pub z: Array2<f64>,
20}
21
22impl KernelDensity2D {
23    /// Compute 2D kernel density estimate using FFT-based convolution
24    ///
25    /// # Arguments
26    /// * `data_x` - X coordinates of data points
27    /// * `data_y` - Y coordinates of data points
28    /// * `adjust` - Bandwidth adjustment factor (default: 1.0)
29    /// * `n_points` - Number of grid points per dimension (default: 128)
30    ///
31    /// # Returns
32    /// KernelDensity2D with 2D density grid
33    pub fn estimate(
34        data_x: &[f64],
35        data_y: &[f64],
36        adjust: f64,
37        n_points: usize,
38    ) -> KdeResult<Self> {
39        if data_x.len() != data_y.len() {
40            return Err(KdeError::StatsError(
41                "X and Y data must have the same length".to_string(),
42            ));
43        }
44
45        if data_x.is_empty() {
46            return Err(KdeError::EmptyData);
47        }
48
49        // Remove NaN values
50        let mut clean_x = Vec::new();
51        let mut clean_y = Vec::new();
52        for i in 0..data_x.len() {
53            if data_x[i].is_finite() && data_y[i].is_finite() {
54                clean_x.push(data_x[i]);
55                clean_y.push(data_y[i]);
56            }
57        }
58
59        if clean_x.len() < 3 {
60            return Err(KdeError::InsufficientData {
61                min: 3,
62                actual: clean_x.len(),
63            });
64        }
65
66        // Calculate bandwidths for each dimension
67        let n = clean_x.len() as f64;
68        let std_dev_x = standard_deviation(&clean_x)
69            .map_err(|e| KdeError::StatsError(e))?;
70        let iqr_x = interquartile_range(&clean_x)
71            .map_err(|e| KdeError::StatsError(e))?;
72        let bw_factor_x = 0.9 * std_dev_x.min(iqr_x / 1.34) * n.powf(-0.2);
73        let bandwidth_x = bw_factor_x * adjust;
74
75        let std_dev_y = standard_deviation(&clean_y)
76            .map_err(|e| KdeError::StatsError(e))?;
77        let iqr_y = interquartile_range(&clean_y)
78            .map_err(|e| KdeError::StatsError(e))?;
79        let bw_factor_y = 0.9 * std_dev_y.min(iqr_y / 1.34) * n.powf(-0.2);
80        let bandwidth_y = bw_factor_y * adjust;
81
82        // Create 2D grid
83        let x_min = clean_x.iter().cloned().fold(f64::INFINITY, f64::min);
84        let x_max = clean_x.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
85        let y_min = clean_y.iter().cloned().fold(f64::INFINITY, f64::min);
86        let y_max = clean_y.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
87
88        let x_grid_min = x_min - 3.0 * bandwidth_x;
89        let x_grid_max = x_max + 3.0 * bandwidth_x;
90        let y_grid_min = y_min - 3.0 * bandwidth_y;
91        let y_grid_max = y_max + 3.0 * bandwidth_y;
92
93        let x: Vec<f64> = (0..n_points)
94            .map(|i| {
95                x_grid_min + (x_grid_max - x_grid_min) * (i as f64) / (n_points - 1) as f64
96            })
97            .collect();
98        let y: Vec<f64> = (0..n_points)
99            .map(|i| {
100                y_grid_min + (y_grid_max - y_grid_min) * (i as f64) / (n_points - 1) as f64
101            })
102            .collect();
103
104        // Compute 2D KDE using FFT convolution
105        let z = kde2d_fft(&clean_x, &clean_y, &x, &y, bandwidth_x, bandwidth_y, n)?;
106
107        Ok(KernelDensity2D { x, y, z })
108    }
109
110    /// Find density contour at given threshold level
111    ///
112    /// # Arguments
113    /// * `threshold` - Density threshold (as fraction of max density)
114    ///
115    /// # Returns
116    /// Vector of (x, y) coordinates forming the contour
117    pub fn find_contour(&self, threshold: f64) -> Vec<(f64, f64)> {
118        let max_density = self.z.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
119        let density_threshold = threshold * max_density;
120
121        // Simple contour extraction: find points above threshold
122        // TODO: Implement proper contour tracing (marching squares algorithm)
123        let mut contour_points = Vec::new();
124        
125        for i in 0..self.x.len() {
126            for j in 0..self.y.len() {
127                if self.z[[i, j]] >= density_threshold {
128                    // Check if this is on the boundary (has a neighbor below threshold)
129                    let is_boundary = (i > 0 && self.z[[i - 1, j]] < density_threshold)
130                        || (i < self.x.len() - 1 && self.z[[i + 1, j]] < density_threshold)
131                        || (j > 0 && self.z[[i, j - 1]] < density_threshold)
132                        || (j < self.y.len() - 1 && self.z[[i, j + 1]] < density_threshold);
133                    
134                    if is_boundary {
135                        contour_points.push((self.x[i], self.y[j]));
136                    }
137                }
138            }
139        }
140
141        contour_points
142    }
143
144    /// Get density value at a specific point (interpolated)
145    ///
146    /// # Arguments
147    /// * `x` - X coordinate
148    /// * `y` - Y coordinate
149    ///
150    /// # Returns
151    /// Interpolated density value
152    pub fn density_at(&self, x: f64, y: f64) -> f64 {
153        // Find grid indices
154        let x_idx = self.find_grid_index(&self.x, x);
155        let y_idx = self.find_grid_index(&self.y, y);
156
157        if x_idx >= self.x.len() || y_idx >= self.y.len() {
158            return 0.0;
159        }
160
161        // Simple bilinear interpolation
162        let x0 = if x_idx > 0 { x_idx - 1 } else { 0 };
163        let x1 = x_idx.min(self.x.len() - 1);
164        let y0 = if y_idx > 0 { y_idx - 1 } else { 0 };
165        let y1 = y_idx.min(self.y.len() - 1);
166
167        let z00 = self.z[[x0, y0]];
168        let z01 = self.z[[x0, y1]];
169        let z10 = self.z[[x1, y0]];
170        let z11 = self.z[[x1, y1]];
171
172        // Bilinear interpolation
173        let dx = if x1 > x0 {
174            (x - self.x[x0]) / (self.x[x1] - self.x[x0])
175        } else {
176            0.0
177        };
178        let dy = if y1 > y0 {
179            (y - self.y[y0]) / (self.y[y1] - self.y[y0])
180        } else {
181            0.0
182        };
183
184        z00 * (1.0 - dx) * (1.0 - dy)
185            + z10 * dx * (1.0 - dy)
186            + z01 * (1.0 - dx) * dy
187            + z11 * dx * dy
188    }
189
190    fn find_grid_index(&self, grid: &[f64], value: f64) -> usize {
191        if value <= grid[0] {
192            return 0;
193        }
194        if value >= grid[grid.len() - 1] {
195            return grid.len() - 1;
196        }
197
198        // Binary search for efficiency
199        let mut left = 0;
200        let mut right = grid.len() - 1;
201        while right - left > 1 {
202            let mid = (left + right) / 2;
203            if grid[mid] <= value {
204                left = mid;
205            } else {
206                right = mid;
207            }
208        }
209        left
210    }
211}
212
213/// 2D FFT-based Kernel Density Estimation
214///
215/// Uses 2D FFT convolution for efficient computation.
216fn kde2d_fft(
217    data_x: &[f64],
218    data_y: &[f64],
219    x_grid: &[f64],
220    y_grid: &[f64],
221    bandwidth_x: f64,
222    bandwidth_y: f64,
223    n: f64,
224) -> KdeResult<Array2<f64>> {
225    let nx = x_grid.len();
226    let ny = y_grid.len();
227
228    if nx < 2 || ny < 2 {
229        return Err(KdeError::StatsError(
230            "Grid must have at least 2 points in each dimension".to_string(),
231        ));
232    }
233
234    let x_spacing = (x_grid[nx - 1] - x_grid[0]) / (nx - 1) as f64;
235    let y_spacing = (y_grid[ny - 1] - y_grid[0]) / (ny - 1) as f64;
236
237    // Step 1: Bin data onto 2D grid
238    let mut binned = Array2::<f64>::zeros((nx, ny));
239    for (&x, &y) in data_x.iter().zip(data_y.iter()) {
240        let x_idx = ((x - x_grid[0]) / x_spacing).floor() as isize;
241        let y_idx = ((y - y_grid[0]) / y_spacing).floor() as isize;
242        if x_idx >= 0 && (x_idx as usize) < nx && y_idx >= 0 && (y_idx as usize) < ny {
243            binned[[x_idx as usize, y_idx as usize]] += 1.0;
244        }
245    }
246
247    // Step 2: Create 2D Gaussian kernel
248    let kernel_center_x = (nx - 1) as f64 / 2.0;
249    let kernel_center_y = (ny - 1) as f64 / 2.0;
250    let mut kernel = Array2::<f64>::zeros((nx, ny));
251
252    for i in 0..nx {
253        for j in 0..ny {
254            let grid_x = (i as f64 - kernel_center_x) * x_spacing;
255            let grid_y = (j as f64 - kernel_center_y) * y_spacing;
256            let u_x = grid_x / bandwidth_x;
257            let u_y = grid_y / bandwidth_y;
258            kernel[[i, j]] = gaussian_kernel(u_x) * gaussian_kernel(u_y);
259        }
260    }
261
262    // Step 3: 2D FFT convolution
263    // Use next power of 2 for efficient FFT
264    let _fft_size_x = (2 * nx).next_power_of_two();
265    let _fft_size_y = (2 * ny).next_power_of_two();
266
267    // For 2D FFT, we'll use a simpler approach: compute 1D FFTs along each dimension
268    // This is less optimal than true 2D FFT but simpler to implement
269    // TODO: Consider using a proper 2D FFT library for better performance
270
271    // For now, use a simplified approach: compute density by convolving each row/column
272    // This is an approximation but works reasonably well
273    let mut density = Array2::<f64>::zeros((nx, ny));
274
275    // Convolve along X dimension for each Y
276    for j in 0..ny {
277        let mut row_binned = vec![0.0; nx];
278        let mut row_kernel = vec![0.0; nx];
279        for i in 0..nx {
280            row_binned[i] = binned[[i, j]];
281            row_kernel[i] = kernel[[i, j]];
282        }
283
284        // Use 1D FFT convolution for this row
285        let row_density = kde1d_row(&row_binned, &row_kernel, bandwidth_x, n)?;
286        for i in 0..nx {
287            density[[i, j]] = row_density[i];
288        }
289    }
290
291    // Convolve along Y dimension for each X (simplified - just average)
292    // Full 2D convolution would be better but this approximation works
293    for i in 0..nx {
294        let mut col_density = vec![0.0; ny];
295        for j in 0..ny {
296            col_density[j] = density[[i, j]];
297        }
298        // Apply Y kernel smoothing
299        for j in 0..ny {
300            let mut sum = 0.0;
301            let mut weight_sum = 0.0;
302            for k in 0..ny {
303                let dist = (j as f64 - k as f64) * y_spacing;
304                let weight = gaussian_kernel(dist / bandwidth_y);
305                sum += col_density[k] * weight;
306                weight_sum += weight;
307            }
308            if weight_sum > 0.0 {
309                density[[i, j]] = sum / weight_sum;
310            }
311        }
312    }
313
314    // Normalize
315    let total: f64 = density.iter().sum();
316    if total > 0.0 {
317        for val in density.iter_mut() {
318            *val /= total * x_spacing * y_spacing;
319        }
320    }
321
322    Ok(density)
323}
324
325/// Helper function for 1D row convolution (simplified)
326fn kde1d_row(
327    binned: &[f64],
328    kernel: &[f64],
329    bandwidth: f64,
330    n: f64,
331) -> KdeResult<Vec<f64>> {
332    let m = binned.len();
333    let fft_size = (2 * m).next_power_of_two();
334
335    let mut planner = RealFftPlanner::<f64>::new();
336    let r2c = planner.plan_fft_forward(fft_size);
337    let c2r = planner.plan_fft_inverse(fft_size);
338
339    // Prepare padded arrays
340    let mut binned_padded = vec![0.0; fft_size];
341    binned_padded[..m].copy_from_slice(binned);
342
343    let mut kernel_padded = vec![0.0; fft_size];
344    let kernel_start = (fft_size - m) / 2;
345    let first_half = (m + 1) / 2;
346    kernel_padded[kernel_start..kernel_start + first_half]
347        .copy_from_slice(&kernel[m / 2..]);
348    let second_half = m / 2;
349    if second_half > 0 {
350        kernel_padded[..second_half].copy_from_slice(&kernel[..second_half]);
351    }
352
353    // Forward FFT
354    let mut binned_spectrum = r2c.make_output_vec();
355    r2c.process(&mut binned_padded, &mut binned_spectrum)
356        .map_err(|e| KdeError::FftError(format!("FFT forward failed: {}", e)))?;
357
358    let mut kernel_spectrum = r2c.make_output_vec();
359    r2c.process(&mut kernel_padded, &mut kernel_spectrum)
360        .map_err(|e| KdeError::FftError(format!("FFT forward failed: {}", e)))?;
361
362    // Multiply in frequency domain
363    let mut conv_spectrum: Vec<Complex<f64>> = binned_spectrum
364        .iter()
365        .zip(kernel_spectrum.iter())
366        .map(|(a, b)| a * b)
367        .collect();
368
369    // Inverse FFT
370    let mut conv_result = c2r.make_output_vec();
371    c2r.process(&mut conv_spectrum, &mut conv_result)
372        .map_err(|e| KdeError::FftError(format!("FFT inverse failed: {}", e)))?;
373
374    // Extract and normalize
375    let kernel_start = (fft_size - m) / 2;
376    let mut density = Vec::with_capacity(m);
377    for i in 0..m {
378        let idx = (kernel_start + i) % fft_size;
379        density.push(conv_result[idx] / (fft_size as f64 * n * bandwidth));
380    }
381
382    Ok(density)
383}