use crate::error::{Error, Result};
/// One-dimensional kernel density estimator built from a fixed set of samples.
///
/// The density is evaluated with a Gaussian kernel centred on every stored
/// sample, all kernels sharing a single `bandwidth`.
#[derive(Clone, Debug)]
pub(crate) struct KernelDensityEstimator {
/// Observations the estimate is built from; both constructors reject an empty vector.
samples: Vec<f64>,
/// Scale of each kernel: either supplied by the caller or derived from the samples.
bandwidth: f64,
}
impl KernelDensityEstimator {
    /// Builds an estimator whose bandwidth is chosen automatically with
    /// Scott's rule (see [`Self::scotts_rule`]).
    ///
    /// # Errors
    ///
    /// Returns [`Error::EmptySamples`] when `samples` is empty.
    pub(crate) fn new(samples: Vec<f64>) -> Result<Self> {
        if samples.is_empty() {
            return Err(Error::EmptySamples);
        }
        let bandwidth = Self::scotts_rule(&samples);
        Ok(Self { samples, bandwidth })
    }

    /// Builds an estimator with an explicit, caller-supplied `bandwidth`.
    ///
    /// # Errors
    ///
    /// * [`Error::EmptySamples`] when `samples` is empty.
    /// * [`Error::InvalidBandwidth`] when `bandwidth` is not a finite,
    ///   strictly positive number (zero, negative, NaN or infinite).
    pub(crate) fn with_bandwidth(samples: Vec<f64>, bandwidth: f64) -> Result<Self> {
        if samples.is_empty() {
            return Err(Error::EmptySamples);
        }
        // NaN fails every ordered comparison, so `bandwidth <= 0.0` alone would
        // let it through; an infinite bandwidth collapses the density to zero.
        if !bandwidth.is_finite() || bandwidth <= 0.0 {
            return Err(Error::InvalidBandwidth(bandwidth));
        }
        Ok(Self { samples, bandwidth })
    }

    /// Scott's rule of thumb for one dimension: `n^(-1/5) * std_dev`.
    ///
    /// Falls back to a bandwidth of `1.0` when the spread is degenerate
    /// (below `f64::EPSILON`, e.g. a single sample or identical samples) or
    /// when the computed value is not finite (NaN/infinite samples, overflow),
    /// so the stored bandwidth is always finite and strictly positive.
    // Sample counts are far below 2^53, so the usize -> f64 cast is exact in practice.
    #[allow(clippy::cast_precision_loss)]
    fn scotts_rule(samples: &[f64]) -> f64 {
        let n = samples.len() as f64;
        let std_dev = Self::sample_std_dev(samples);
        if std_dev < f64::EPSILON {
            return 1.0;
        }
        let bandwidth = n.powf(-0.2) * std_dev;
        // A NaN spread slips past the `<` test above; catch it (and infinity) here.
        if bandwidth.is_finite() {
            bandwidth
        } else {
            1.0
        }
    }

    /// Standard deviation of `samples` using the population formula
    /// (the squared deviations are divided by `n`, not `n - 1`).
    ///
    /// The slice must be non-empty; an empty slice yields NaN.
    // Sample counts are far below 2^53, so the usize -> f64 cast is exact in practice.
    #[allow(clippy::cast_precision_loss)]
    fn sample_std_dev(samples: &[f64]) -> f64 {
        let n = samples.len() as f64;
        let mean = samples.iter().sum::<f64>() / n;
        let variance = samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
        variance.sqrt()
    }

    /// Evaluates the estimated probability density at `x`.
    ///
    /// This is the average over all samples of a normal density centred on
    /// the sample with standard deviation equal to the bandwidth.
    // Sample counts are far below 2^53, so the usize -> f64 cast is exact in practice.
    #[allow(clippy::cast_precision_loss)]
    pub(crate) fn pdf(&self, x: f64) -> f64 {
        let n = self.samples.len() as f64;
        let inv_bandwidth = 1.0 / self.bandwidth;
        // 1 / (h * sqrt(2 * pi)): the Gaussian normalisation, hoisted out of the loop.
        let normalization = inv_bandwidth / (2.0 * core::f64::consts::PI).sqrt();
        let density: f64 = self
            .samples
            .iter()
            .map(|&xi| {
                let z = (x - xi) * inv_bandwidth;
                normalization * (-0.5 * z * z).exp()
            })
            .sum();
        density / n
    }

    /// Draws one value from the estimated density.
    ///
    /// A stored sample is picked uniformly at random and perturbed by Gaussian
    /// noise (Box-Muller transform) scaled by the bandwidth.
    pub(crate) fn sample(&self, rng: &mut fastrand::Rng) -> f64 {
        let idx = rng.usize(0..self.samples.len());
        let center = self.samples[idx];
        // `ln(0)` is -inf and would turn the result into +/-inf or NaN, so a
        // draw of exactly zero is rejected and redrawn. Every other draw is
        // used unchanged, which keeps seeded streams identical to before.
        let u1: f64 = loop {
            let candidate = rng.f64();
            if candidate > 0.0 {
                break candidate;
            }
        };
        let u2: f64 = rng.f64();
        let z = (-2.0 * u1.ln()).sqrt() * (2.0 * core::f64::consts::PI * u2).cos();
        center + z * self.bandwidth
    }

    /// Bandwidth currently in use; exposed for tests only.
    #[cfg(test)]
    pub(crate) fn bandwidth(&self) -> f64 {
        self.bandwidth
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Density is positive at every sample and decays far from the data.
    #[test]
    fn test_kde_pdf_basic() {
        let samples = vec![0.0, 1.0, 2.0];
        let kde = KernelDensityEstimator::new(samples).unwrap();
        assert!(kde.pdf(0.0) > 0.0);
        assert!(kde.pdf(1.0) > 0.0);
        assert!(kde.pdf(2.0) > 0.0);
        let mid_density = kde.pdf(1.0);
        let far_density = kde.pdf(10.0);
        assert!(mid_density > far_density);
    }

    /// Midpoint-rule integration over a window wide enough to hold
    /// essentially all of the mass must come out close to one.
    #[test]
    fn test_kde_pdf_integrates_to_one() {
        let samples = vec![0.0, 1.0, 2.0, 3.0, 4.0];
        let kde = KernelDensityEstimator::new(samples).unwrap();
        let n_points = 10000;
        let low = -10.0;
        let high = 15.0;
        let dx = (high - low) / f64::from(n_points);
        let integral: f64 = (0..n_points)
            .map(|i| {
                let x = low + (f64::from(i) + 0.5) * dx;
                kde.pdf(x) * dx
            })
            .sum();
        assert!(
            (integral - 1.0).abs() < 0.01,
            "Integral = {integral}, expected ~1.0"
        );
    }

    /// An explicit bandwidth is stored verbatim.
    #[test]
    fn test_kde_with_bandwidth() {
        let samples = vec![0.0, 1.0, 2.0];
        let kde = KernelDensityEstimator::with_bandwidth(samples, 0.5).unwrap();
        assert!((kde.bandwidth() - 0.5).abs() < f64::EPSILON);
        assert!(kde.pdf(1.0) > 0.0);
    }

    /// Draws stay close to the data.
    #[test]
    fn test_kde_sample_in_reasonable_range() {
        let samples = vec![0.0, 1.0, 2.0, 3.0, 4.0];
        let kde = KernelDensityEstimator::new(samples).unwrap();
        // Fixed seed: an entropy-seeded generator would make any failure
        // impossible to reproduce.
        let mut rng = fastrand::Rng::with_seed(0x5EED);
        for _ in 0..100 {
            let s = kde.sample(&mut rng);
            assert!(s > -10.0 && s < 15.0, "Sample {s} outside expected range");
        }
    }

    /// A single observation has zero spread; the estimator must still
    /// produce a usable density.
    #[test]
    fn test_kde_single_sample() {
        let samples = vec![5.0];
        let kde = KernelDensityEstimator::new(samples).unwrap();
        assert!(kde.pdf(5.0) > 0.0);
        assert!(kde.pdf(4.5) > 0.0);
    }

    /// Identical observations must not yield a zero bandwidth.
    #[test]
    fn test_kde_identical_samples() {
        let samples = vec![3.0, 3.0, 3.0, 3.0];
        let kde = KernelDensityEstimator::new(samples).unwrap();
        assert!(kde.bandwidth() > 0.0);
        assert!(kde.pdf(3.0) > 0.0);
    }

    /// The automatic bandwidth for the integers 0..=9 lands in a sane range.
    #[test]
    fn test_scotts_rule_bandwidth() {
        let samples = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let kde = KernelDensityEstimator::new(samples).unwrap();
        let bandwidth = kde.bandwidth();
        assert!(
            bandwidth > 1.0 && bandwidth < 3.0,
            "Bandwidth {bandwidth} outside expected range"
        );
    }

    /// An empty sample set is rejected.
    #[test]
    fn test_kde_empty_samples() {
        let samples: Vec<f64> = vec![];
        let result = KernelDensityEstimator::new(samples);
        assert!(matches!(result, Err(Error::EmptySamples)));
    }

    /// A zero bandwidth is rejected.
    #[test]
    fn test_kde_zero_bandwidth() {
        let samples = vec![1.0, 2.0, 3.0];
        let result = KernelDensityEstimator::with_bandwidth(samples, 0.0);
        assert!(matches!(result, Err(Error::InvalidBandwidth(_))));
    }

    /// A negative bandwidth is rejected.
    #[test]
    fn test_kde_negative_bandwidth() {
        let samples = vec![1.0, 2.0, 3.0];
        let result = KernelDensityEstimator::with_bandwidth(samples, -1.0);
        assert!(matches!(result, Err(Error::InvalidBandwidth(_))));
    }
}