Skip to main content

flow_utils/kde/
mod.rs

1//! Kernel Density Estimation (KDE) module
2//!
3//! Provides FFT-accelerated KDE with optional GPU support.
4
5mod fft;
6mod kde2d;
7#[cfg(feature = "gpu")]
8mod gpu;
9
10use crate::common::{interquartile_range, standard_deviation};
11use thiserror::Error;
12
13pub use fft::kde_fft;
14pub use kde2d::KernelDensity2D;
15#[cfg(feature = "gpu")]
16pub use gpu::kde_fft_gpu;
17
18/// Error type for KDE operations
19#[derive(Error, Debug)]
20pub enum KdeError {
21    #[error("Empty data for KDE")]
22    EmptyData,
23    #[error("Insufficient data: need at least {min} points, got {actual}")]
24    InsufficientData { min: usize, actual: usize },
25    #[error("Statistics error: {0}")]
26    StatsError(String),
27    #[error("FFT error: {0}")]
28    FftError(String),
29}
30
31pub type KdeResult<T> = Result<T, KdeError>;
32
/// Kernel Density Estimation using Gaussian kernel with FFT acceleration
///
/// This is a simplified implementation of R's density() function
/// with automatic bandwidth selection using Silverman's rule of thumb.
/// Uses FFT-based convolution for O(n log n) performance instead of O(n*m).
///
/// `x` and `y` are parallel vectors: `y[i]` is the estimated density at grid
/// location `x[i]`. The standard derives let a result be printed, copied and
/// compared, and allow `unwrap_err` / `expect_err` on `KdeResult<KernelDensity>`.
#[derive(Debug, Clone, PartialEq)]
pub struct KernelDensity {
    /// Grid points at which the density was evaluated
    pub x: Vec<f64>,
    /// Density values, one per grid point
    pub y: Vec<f64>,
}
44
45impl KernelDensity {
46    /// Compute kernel density estimate using FFT-based convolution
47    ///
48    /// # Arguments
49    /// * `data` - Input data
50    /// * `adjust` - Bandwidth adjustment factor (default: 1.0)
51    /// * `n_points` - Number of grid points (default: 512)
52    pub fn estimate(data: &[f64], adjust: f64, n_points: usize) -> KdeResult<Self> {
53        if data.is_empty() {
54            return Err(KdeError::EmptyData);
55        }
56
57        // Remove NaN values
58        let clean_data: Vec<f64> = data.iter().filter(|x| x.is_finite()).copied().collect();
59
60        if clean_data.len() < 3 {
61            return Err(KdeError::InsufficientData {
62                min: 3,
63                actual: clean_data.len(),
64            });
65        }
66
67        // Calculate bandwidth using Silverman's rule of thumb
68        let n = clean_data.len() as f64;
69        let std_dev = standard_deviation(&clean_data)
70            .map_err(|e| KdeError::StatsError(e))?;
71        let iqr = interquartile_range(&clean_data)
72            .map_err(|e| KdeError::StatsError(e))?;
73
74        // Silverman's rule: bw = 0.9 * min(sd, IQR/1.34) * n^(-1/5)
75        let bw_factor = 0.9 * std_dev.min(iqr / 1.34) * n.powf(-0.2);
76        let bandwidth = bw_factor * adjust;
77
78        // Create grid
79        let data_min = clean_data.iter().cloned().fold(f64::INFINITY, f64::min);
80        let data_max = clean_data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
81        let grid_min = data_min - 3.0 * bandwidth;
82        let grid_max = data_max + 3.0 * bandwidth;
83
84        let x: Vec<f64> = (0..n_points)
85            .map(|i| grid_min + (grid_max - grid_min) * (i as f64) / (n_points - 1) as f64)
86            .collect();
87
88        // Use FFT-based KDE for better performance
89        // Use GPU if available (batched operations provide speedup even for smaller datasets)
90        #[cfg(feature = "gpu")]
91        let y = if crate::gpu::is_gpu_available() {
92            kde_fft_gpu(&clean_data, &x, bandwidth, n)?
93        } else {
94            kde_fft(&clean_data, &x, bandwidth, n)?
95        };
96        
97        #[cfg(not(feature = "gpu"))]
98        let y = kde_fft(&clean_data, &x, bandwidth, n)?;
99
100        Ok(KernelDensity { x, y })
101    }
102
103    /// Find local maxima (peaks) in the density estimate
104    ///
105    /// # Arguments
106    /// * `peak_removal` - Minimum peak height as fraction of max density
107    ///
108    /// # Returns
109    /// Vector of x-coordinates where peaks occur
110    pub fn find_peaks(&self, peak_removal: f64) -> Vec<f64> {
111        if self.y.len() < 3 {
112            return Vec::new();
113        }
114
115        let max_y = self.y.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
116        let threshold = peak_removal * max_y;
117
118        let mut peaks = Vec::new();
119
120        for i in 1..self.y.len() - 1 {
121            // Check if this is a local maximum above threshold
122            if self.y[i] > self.y[i - 1] && self.y[i] > self.y[i + 1] && self.y[i] > threshold {
123                peaks.push(self.x[i]);
124            }
125        }
126
127        // If no peaks found, return the maximum point
128        if peaks.is_empty() {
129            if let Some((idx, _)) = self
130                .y
131                .iter()
132                .enumerate()
133                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
134            {
135                peaks.push(self.x[idx]);
136            }
137        }
138
139        peaks
140    }
141
142    /// Get density value at a specific point using linear interpolation
143    ///
144    /// # Arguments
145    /// * `x` - The point at which to evaluate the density
146    ///
147    /// # Returns
148    /// The interpolated density value, or 0.0 if x is outside the grid range
149    pub fn density_at(&self, x: f64) -> f64 {
150        if self.x.is_empty() || self.y.is_empty() {
151            return 0.0;
152        }
153
154        // Handle out-of-bounds
155        if x <= self.x[0] {
156            return self.y[0];
157        }
158        if x >= self.x[self.x.len() - 1] {
159            return self.y[self.y.len() - 1];
160        }
161
162        // Find the two grid points to interpolate between
163        let mut left_idx = 0;
164        let mut right_idx = self.x.len() - 1;
165
166        // Binary search for the interval
167        while right_idx - left_idx > 1 {
168            let mid = (left_idx + right_idx) / 2;
169            if self.x[mid] <= x {
170                left_idx = mid;
171            } else {
172                right_idx = mid;
173            }
174        }
175
176        // Linear interpolation
177        let x0 = self.x[left_idx];
178        let x1 = self.x[right_idx];
179        let y0 = self.y[left_idx];
180        let y1 = self.y[right_idx];
181
182        if (x1 - x0).abs() < 1e-10 {
183            y0
184        } else {
185            y0 + (y1 - y0) * (x - x0) / (x1 - x0)
186        }
187    }
188}