//! Kernel density estimator implementation.
use std::{
    fmt::Debug,
    ops::{Add, Div},
};
use fastrand::Rng;
pub use self::component::Component;
use crate::{Density, Sample};
mod component;
/// [Kernel density estimator][1].
///
/// It is used to model «good» and «bad» parameter distributions, but can also be used standalone.
///
/// # Type parameters
///
/// - `C`: cloneable iterator over the KDE's components, which implement [`Density`] and [`Sample`].
///
/// [1]: https://en.wikipedia.org/wiki/Kernel_density_estimation
#[derive(Copy, Clone, Debug)]
pub struct KernelDensityEstimator<C>(pub C);
impl<P, D, C> Density<P, D> for KernelDensityEstimator<C>
where
    C: Iterator + Clone,
    C::Item: Density<P, D>,
    P: Copy,
    D: Add<Output = D> + Div<Output = D> + num_traits::FromPrimitive + num_traits::Zero,
{
    /// Calculate the KDE's density at the specified point.
    ///
    /// The method returns `D::zero()` if there are no components.
    #[allow(clippy::cast_precision_loss)]
    fn density(&self, at: P) -> D {
        // Count the components and sum their densities in a single pass.
        let (n_points, sum) = self
            .0
            .clone()
            .fold((0_usize, D::zero()), |(n, sum), component| {
                (n + 1, sum + component.density(at))
            });
        if n_points == 0 {
            D::zero()
        } else {
            sum / D::from_usize(n_points).unwrap()
        }
    }
}
impl<T, C> Sample<Option<T>> for KernelDensityEstimator<C>
where
    C: Iterator + Clone,
    C::Item: Sample<T>,
{
    /// Sample a random point from the KDE.
    ///
    /// The algorithm uses [reservoir sampling][1] to pick a random component,
    /// and then samples a point from that component.
    ///
    /// The method returns [`None`] if the estimator has no components.
    ///
    /// [1]: https://en.wikipedia.org/wiki/Reservoir_sampling
    fn sample(&self, rng: &mut Rng) -> Option<T> {
        let sample = self
            .0
            .clone()
            .enumerate()
            // Keep the `i`-th component with probability `1 / (i + 1)`;
            // the last component kept is the one selected.
            .filter(|(i, _)| rng.usize(0..=*i) == 0)
            .last()?
            .1
            .sample(rng);
        Some(sample)
    }
}
#[cfg(test)]
mod tests {
    use std::iter;

    use super::*;
    use crate::{consts::f64::SQRT_3, kernel::continuous::Uniform};

    #[test]
    fn sample_single_component_ok() {
        let component = Component {
            kernel: Uniform,
            location: 0.0,
            bandwidth: 1.0,
        };
        let kde = KernelDensityEstimator(iter::once(component));
        let mut rng = Rng::new();
        let sample = kde.sample(&mut rng).unwrap();
        assert!((-SQRT_3..=SQRT_3).contains(&sample));
        // Ensure that the iterator can be reused.
        let _ = kde.sample(&mut rng).unwrap();
    }
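
    // A minimal, self-contained sketch of the single-pass selection that `sample` performs
    // on the component iterator. `Lcg` and `reservoir_pick` are hypothetical helpers that
    // exist only for this illustration: `Lcg` is a toy generator standing in for
    // `fastrand::Rng`, and it is an assumption of the sketch, not part of the crate's API.
    struct Lcg(u64);

    impl Lcg {
        /// Returns a pseudo-random index in `0..=upper` (slightly biased, fine for a sketch).
        #[allow(clippy::cast_possible_truncation)]
        fn index(&mut self, upper: usize) -> usize {
            // Multiplier and increment of Knuth's MMIX linear congruential generator.
            self.0 = self
                .0
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            ((self.0 >> 33) as usize) % (upper + 1)
        }
    }

    /// Keeps the `i`-th item with probability `1 / (i + 1)`; the last kept item is the pick.
    fn reservoir_pick<I: Iterator>(items: I, rng: &mut Lcg) -> Option<I::Item> {
        items
            .enumerate()
            .filter(|(i, _)| rng.index(*i) == 0)
            .last()
            .map(|(_, item)| item)
    }

    #[test]
    fn reservoir_pick_sketch_ok() {
        let mut rng = Lcg(42);
        // An empty iterator yields nothing, mirroring the `None` returned by `sample`.
        assert_eq!(reservoir_pick(std::iter::empty::<u8>(), &mut rng), None);
        // The first item always passes the filter, so a single item is always picked.
        assert_eq!(reservoir_pick(std::iter::once(7), &mut rng), Some(7));
        // With several items, the pick is always one of them.
        let picked = reservoir_pick(0..10, &mut rng).unwrap();
        assert!((0..10).contains(&picked));
    }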
}