use crate::error::{Error, Result};
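/// Kernel density estimator over d-dimensional samples, using a product of
/// per-dimension Gaussian kernels with a diagonal bandwidth vector.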
#[derive(Clone, Debug)]
pub(crate) struct MultivariateKDE {
samples: Vec<Vec<f64>>,
bandwidths: Vec<f64>,
n_dims: usize,
}
impl MultivariateKDE {
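    /// Builds a KDE from `samples`, selecting per-dimension bandwidths with
    /// Scott's rule.
    ///
    /// Returns an error if `samples` is empty, if the points are
    /// zero-dimensional, or if any point's dimension differs from the first.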
#[allow(clippy::cast_precision_loss)]
pub(crate) fn new(samples: Vec<Vec<f64>>) -> Result<Self> {
if samples.is_empty() {
return Err(Error::EmptySamples);
}
let n_dims = samples[0].len();
if n_dims == 0 {
return Err(Error::ZeroDimensions);
}
for (i, sample) in samples.iter().enumerate() {
if sample.len() != n_dims {
return Err(Error::DimensionMismatch {
expected: n_dims,
got: sample.len(),
sample_index: i,
});
}
}
let bandwidths = Self::scotts_rule_multivariate(&samples, n_dims);
Ok(Self {
samples,
bandwidths,
n_dims,
})
}
#[cfg(test)]
pub(crate) fn with_bandwidths(samples: Vec<Vec<f64>>, bandwidths: Vec<f64>) -> Result<Self> {
if samples.is_empty() {
return Err(Error::EmptySamples);
}
let n_dims = samples[0].len();
if n_dims == 0 {
return Err(Error::ZeroDimensions);
}
for (i, sample) in samples.iter().enumerate() {
if sample.len() != n_dims {
return Err(Error::DimensionMismatch {
expected: n_dims,
got: sample.len(),
sample_index: i,
});
}
}
if bandwidths.len() != n_dims {
return Err(Error::BandwidthDimensionMismatch {
expected: n_dims,
got: bandwidths.len(),
});
}
for &bw in &bandwidths {
if bw <= 0.0 {
return Err(Error::InvalidBandwidth(bw));
}
}
Ok(Self {
samples,
bandwidths,
n_dims,
})
}
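    /// Per-dimension bandwidths via Scott's rule:
    /// `h_j = sigma_j * n^(-1 / (d + 4))`, where `sigma_j` is the standard
    /// deviation along dimension `j`. Dimensions with (near-)zero spread fall
    /// back to a bandwidth of 1.0 to avoid degenerate kernels.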
#[allow(clippy::cast_precision_loss)]
fn scotts_rule_multivariate(samples: &[Vec<f64>], n_dims: usize) -> Vec<f64> {
let n = samples.len() as f64;
let d = n_dims as f64;
let exponent = -1.0 / (d + 4.0);
let scale_factor = n.powf(exponent);
(0..n_dims)
.map(|dim| {
let std_dev = Self::dimension_std_dev(samples, dim);
if std_dev < f64::EPSILON {
1.0
} else {
scale_factor * std_dev
}
})
.collect()
}
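    /// Population standard deviation of the samples along dimension `dim`
    /// (variance divided by `n`, not `n - 1`).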
#[allow(clippy::cast_precision_loss)]
fn dimension_std_dev(samples: &[Vec<f64>], dim: usize) -> f64 {
let n = samples.len() as f64;
let values: Vec<f64> = samples.iter().map(|s| s[dim]).collect();
let mean = values.iter().sum::<f64>() / n;
let variance = values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
variance.sqrt()
}
#[cfg(test)]
pub(crate) fn n_dims(&self) -> usize {
self.n_dims
}
#[cfg(test)]
pub(crate) fn n_samples(&self) -> usize {
self.samples.len()
}
#[cfg(test)]
pub(crate) fn bandwidths(&self) -> &[f64] {
&self.bandwidths
}
#[cfg(test)]
pub(crate) fn samples(&self) -> &[Vec<f64>] {
&self.samples
}
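    /// Log-density of the KDE at `x`:
    /// `log p(x) = -ln(n) + logsumexp_i( sum_j log N(x_j; sample_ij, h_j) )`,
    /// evaluated with the log-sum-exp trick for numerical stability.
    ///
    /// # Panics
    /// Panics if `x.len()` does not match the KDE's dimensionality.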
#[allow(clippy::cast_precision_loss)]
pub(crate) fn log_pdf(&self, x: &[f64]) -> f64 {
assert_eq!(
x.len(),
self.n_dims,
"Point dimension {} doesn't match KDE dimension {}",
x.len(),
self.n_dims
);
let n = self.samples.len() as f64;
let log_2pi = (2.0 * core::f64::consts::PI).ln();
let log_norm_per_dim: Vec<f64> = self
.bandwidths
.iter()
.map(|&h| -h.ln() - 0.5 * log_2pi)
.collect();
let log_kernels: Vec<f64> = self
.samples
.iter()
.map(|sample| {
let mut log_kernel_sum = 0.0;
for j in 0..self.n_dims {
let z = (x[j] - sample[j]) / self.bandwidths[j];
log_kernel_sum += log_norm_per_dim[j] - 0.5 * z * z;
}
log_kernel_sum
})
.collect();
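        // Log-sum-exp: shift by the largest log-kernel before exponentiating so
        // that far-away query points underflow gracefully instead of producing NaN.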
let max_log_kernel = log_kernels
.iter()
.copied()
.fold(f64::NEG_INFINITY, f64::max);
if max_log_kernel.is_infinite() && max_log_kernel < 0.0 {
return f64::NEG_INFINITY;
}
let sum_exp: f64 = log_kernels
.iter()
.map(|&lk| (lk - max_log_kernel).exp())
.sum();
-n.ln() + max_log_kernel + sum_exp.ln()
}
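    /// Density at `x`, computed as `exp(log_pdf(x))`. Test-only convenience.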
#[cfg(test)]
pub(crate) fn pdf(&self, x: &[f64]) -> f64 {
self.log_pdf(x).exp()
}
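    /// Draws one point from the KDE: pick a kernel center uniformly at random,
    /// then perturb each coordinate with Gaussian noise (Box-Muller transform)
    /// scaled by that dimension's bandwidth.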
pub(crate) fn sample(&self, rng: &mut fastrand::Rng) -> Vec<f64> {
let idx = rng.usize(0..self.samples.len());
let center = &self.samples[idx];
center
.iter()
.zip(self.bandwidths.iter())
            .map(|(&center_j, &bandwidth_j)| {
                // Box-Muller transform: two uniforms -> one standard normal draw.
                // Use `1.0 - rng.f64()` so u1 is strictly positive and `ln` never sees zero.
                let u1: f64 = 1.0 - rng.f64();
                let u2: f64 = rng.f64();
                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * core::f64::consts::PI * u2).cos();
center_j + z * bandwidth_j
})
.collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_multivariate_kde_new_basic() {
let samples = vec![vec![1.0, 2.0], vec![1.5, 2.5], vec![2.0, 3.0]];
let kde = MultivariateKDE::new(samples).unwrap();
assert_eq!(kde.n_dims(), 2);
assert_eq!(kde.n_samples(), 3);
assert_eq!(kde.bandwidths().len(), 2);
}
#[test]
fn test_multivariate_kde_new_single_sample() {
let samples = vec![vec![1.0, 2.0, 3.0]];
let kde = MultivariateKDE::new(samples).unwrap();
assert_eq!(kde.n_dims(), 3);
assert_eq!(kde.n_samples(), 1);
for &bw in kde.bandwidths() {
assert!((bw - 1.0).abs() < f64::EPSILON);
}
}
#[test]
fn test_multivariate_kde_new_single_dimension() {
let samples = vec![vec![1.0], vec![2.0], vec![3.0], vec![4.0], vec![5.0]];
let kde = MultivariateKDE::new(samples).unwrap();
assert_eq!(kde.n_dims(), 1);
assert_eq!(kde.n_samples(), 5);
assert_eq!(kde.bandwidths().len(), 1);
assert!(kde.bandwidths()[0] > 0.0);
}
#[test]
fn test_multivariate_kde_scotts_rule() {
let samples: Vec<Vec<f64>> = (0..10)
.map(|i| {
let x = f64::from(i);
                vec![x, x * 2.0]
            })
.collect();
let kde = MultivariateKDE::new(samples).unwrap();
let bw = kde.bandwidths();
assert!(
bw[0] > 1.0 && bw[0] < 3.0,
"First bandwidth {} unexpected",
bw[0]
);
assert!(
bw[1] > 2.0 && bw[1] < 6.0,
"Second bandwidth {} unexpected",
bw[1]
);
assert!(
(bw[1] / bw[0] - 2.0).abs() < 0.1,
"Ratio {} not close to 2",
bw[1] / bw[0]
);
}
#[test]
fn test_multivariate_kde_with_bandwidths() {
let samples = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
let bandwidths = vec![0.5, 1.0];
let kde = MultivariateKDE::with_bandwidths(samples, bandwidths).unwrap();
assert_eq!(kde.n_dims(), 2);
assert!((kde.bandwidths()[0] - 0.5).abs() < f64::EPSILON);
assert!((kde.bandwidths()[1] - 1.0).abs() < f64::EPSILON);
}
#[test]
fn test_multivariate_kde_empty_samples() {
let samples: Vec<Vec<f64>> = vec![];
let result = MultivariateKDE::new(samples);
assert!(matches!(result, Err(Error::EmptySamples)));
}
#[test]
fn test_multivariate_kde_zero_dimensions() {
let samples = vec![vec![], vec![]];
let result = MultivariateKDE::new(samples);
assert!(matches!(result, Err(Error::ZeroDimensions)));
}
#[test]
fn test_multivariate_kde_dimension_mismatch() {
        let samples = vec![vec![1.0, 2.0], vec![3.0]];
        let result = MultivariateKDE::new(samples);
assert!(matches!(
result,
Err(Error::DimensionMismatch {
expected: 2,
got: 1,
sample_index: 1
})
));
}
#[test]
fn test_multivariate_kde_with_bandwidths_wrong_length() {
let samples = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let bandwidths = vec![0.5];
        let result = MultivariateKDE::with_bandwidths(samples, bandwidths);
assert!(matches!(
result,
Err(Error::BandwidthDimensionMismatch {
expected: 2,
got: 1
})
));
}
#[test]
fn test_multivariate_kde_with_bandwidths_zero() {
let samples = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let bandwidths = vec![0.5, 0.0];
        let result = MultivariateKDE::with_bandwidths(samples, bandwidths);
assert!(matches!(result, Err(Error::InvalidBandwidth(bw)) if bw == 0.0));
}
#[test]
fn test_multivariate_kde_with_bandwidths_negative() {
let samples = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let bandwidths = vec![0.5, -1.0];
        let result = MultivariateKDE::with_bandwidths(samples, bandwidths);
assert!(
matches!(result, Err(Error::InvalidBandwidth(bw)) if (bw - (-1.0)).abs() < f64::EPSILON)
);
}
#[test]
fn test_multivariate_kde_identical_samples() {
let samples = vec![vec![5.0, 10.0], vec![5.0, 10.0], vec![5.0, 10.0]];
let kde = MultivariateKDE::new(samples).unwrap();
for &bw in kde.bandwidths() {
assert!((bw - 1.0).abs() < f64::EPSILON);
}
}
#[test]
fn test_multivariate_kde_high_dimensional() {
let samples: Vec<Vec<f64>> = (0..20)
.map(|i| {
let x = f64::from(i);
vec![x, x * 0.5, x * 2.0, x * 0.1, x * 10.0]
})
.collect();
let kde = MultivariateKDE::new(samples).unwrap();
assert_eq!(kde.n_dims(), 5);
assert_eq!(kde.n_samples(), 20);
assert_eq!(kde.bandwidths().len(), 5);
for &bw in kde.bandwidths() {
assert!(bw > 0.0);
}
}
#[test]
fn test_multivariate_kde_samples_accessor() {
let samples = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
let kde = MultivariateKDE::new(samples.clone()).unwrap();
assert_eq!(kde.samples(), &samples);
}
#[test]
fn test_multivariate_kde_pdf_basic() {
let samples = vec![vec![0.0, 0.0], vec![1.0, 1.0], vec![2.0, 2.0]];
let kde = MultivariateKDE::new(samples).unwrap();
assert!(kde.pdf(&[0.0, 0.0]) > 0.0);
assert!(kde.pdf(&[1.0, 1.0]) > 0.0);
assert!(kde.pdf(&[2.0, 2.0]) > 0.0);
assert!(kde.pdf(&[0.5, 0.5]) > 0.0);
let near_density = kde.pdf(&[1.0, 1.0]);
let far_density = kde.pdf(&[10.0, 10.0]);
assert!(near_density > far_density);
}
#[test]
fn test_multivariate_kde_pdf_with_custom_bandwidths() {
let samples = vec![vec![0.0, 0.0], vec![1.0, 1.0]];
let bandwidths = vec![0.5, 0.5];
let kde = MultivariateKDE::with_bandwidths(samples, bandwidths).unwrap();
assert!(kde.pdf(&[0.5, 0.5]) > 0.0);
assert!(kde.pdf(&[0.0, 0.0]) > 0.0);
}
#[test]
fn test_multivariate_kde_pdf_single_sample() {
let samples = vec![vec![5.0, 10.0]];
let kde = MultivariateKDE::new(samples).unwrap();
assert!(kde.pdf(&[5.0, 10.0]) > 0.0);
assert!(kde.pdf(&[4.5, 9.5]) > 0.0);
}
#[test]
fn test_multivariate_kde_log_pdf_consistency() {
let samples = vec![vec![0.0, 0.0], vec![1.0, 1.0], vec![2.0, 2.0]];
let kde = MultivariateKDE::new(samples).unwrap();
let test_points = vec![
vec![0.0, 0.0],
vec![1.0, 1.0],
vec![0.5, 0.5],
vec![3.0, 3.0],
];
for point in test_points {
let log_p = kde.log_pdf(&point);
let p = kde.pdf(&point);
let p_from_log = log_p.exp();
assert!(
(p - p_from_log).abs() < 1e-10,
"pdf={p}, exp(log_pdf)={p_from_log}"
);
}
}
#[test]
fn test_multivariate_kde_pdf_integrates_to_one_1d() {
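        // Midpoint Riemann sum over a range wide enough to capture essentially
        // all of the density; the estimate should integrate to ~1.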
let samples = vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0], vec![4.0]];
let kde = MultivariateKDE::new(samples).unwrap();
let n_points = 1000;
let low = -10.0;
let high = 15.0;
let dx = (high - low) / f64::from(n_points);
let integral: f64 = (0..n_points)
.map(|i| {
let x = low + (f64::from(i) + 0.5) * dx;
kde.pdf(&[x]) * dx
})
.sum();
assert!(
(integral - 1.0).abs() < 0.02,
"1D integral = {integral}, expected ~1.0"
);
}
#[test]
fn test_multivariate_kde_pdf_integrates_to_one_2d() {
let samples = vec![
vec![0.0, 0.0],
vec![1.0, 0.0],
vec![0.0, 1.0],
vec![1.0, 1.0],
vec![0.5, 0.5],
];
let kde = MultivariateKDE::new(samples).unwrap();
        let n_points = 100;
        let low = -5.0;
let high = 6.0;
let dx = (high - low) / f64::from(n_points);
let mut integral = 0.0;
for i in 0..n_points {
for j in 0..n_points {
let x = low + (f64::from(i) + 0.5) * dx;
let y = low + (f64::from(j) + 0.5) * dx;
integral += kde.pdf(&[x, y]) * dx * dx;
}
}
assert!(
(integral - 1.0).abs() < 0.05,
"2D integral = {integral}, expected ~1.0"
);
}
#[test]
fn test_multivariate_kde_pdf_symmetry() {
let samples = vec![
vec![1.0, 0.0],
vec![-1.0, 0.0],
vec![0.0, 1.0],
vec![0.0, -1.0],
];
let kde = MultivariateKDE::new(samples).unwrap();
let d1 = kde.pdf(&[0.5, 0.0]);
let d2 = kde.pdf(&[-0.5, 0.0]);
assert!(
(d1 - d2).abs() < 1e-10,
"Symmetric points have different densities: {d1} vs {d2}"
);
let d3 = kde.pdf(&[0.0, 0.5]);
let d4 = kde.pdf(&[0.0, -0.5]);
assert!(
(d3 - d4).abs() < 1e-10,
"Symmetric points have different densities: {d3} vs {d4}"
);
}
#[test]
fn test_multivariate_kde_pdf_high_dimensional() {
let samples: Vec<Vec<f64>> = (0..10)
.map(|i| {
let x = f64::from(i) * 0.1;
                vec![x, x, x, x, x]
            })
.collect();
let kde = MultivariateKDE::new(samples).unwrap();
assert!(kde.pdf(&[0.5, 0.5, 0.5, 0.5, 0.5]) > 0.0);
assert!(kde.pdf(&[0.0, 0.0, 0.0, 0.0, 0.0]) > 0.0);
let log_p = kde.log_pdf(&[0.5, 0.5, 0.5, 0.5, 0.5]);
assert!(log_p.is_finite());
}
#[test]
fn test_multivariate_kde_pdf_numerical_stability() {
let samples = vec![vec![0.0, 0.0], vec![1.0, 1.0]];
let kde = MultivariateKDE::new(samples).unwrap();
let far_pdf = kde.pdf(&[100.0, 100.0]);
assert!(far_pdf >= 0.0);
        assert!(far_pdf.is_finite());
        let far_log_pdf = kde.log_pdf(&[100.0, 100.0]);
        // Far from every sample the log-density may underflow to -inf, but it must never be NaN.
        assert!(!far_log_pdf.is_nan());
}
#[test]
#[should_panic(expected = "Point dimension")]
fn test_multivariate_kde_pdf_wrong_dimension() {
let samples = vec![vec![0.0, 0.0], vec![1.0, 1.0]];
let kde = MultivariateKDE::new(samples).unwrap();
        let _ = kde.pdf(&[0.0]);
    }
#[test]
fn test_multivariate_kde_sample_basic() {
let samples = vec![vec![0.0, 0.0], vec![1.0, 1.0], vec![2.0, 2.0]];
let kde = MultivariateKDE::new(samples).unwrap();
let mut rng = fastrand::Rng::new();
let sample = kde.sample(&mut rng);
assert_eq!(sample.len(), 2);
}
#[test]
fn test_multivariate_kde_sample_in_reasonable_range() {
let samples = vec![
vec![0.0, 0.0],
vec![1.0, 1.0],
vec![2.0, 2.0],
vec![3.0, 3.0],
vec![4.0, 4.0],
];
let kde = MultivariateKDE::new(samples).unwrap();
let mut rng = fastrand::Rng::new();
for _ in 0..100 {
let s = kde.sample(&mut rng);
assert!(
s[0] > -10.0 && s[0] < 15.0,
"Sample dimension 0: {} outside expected range",
s[0]
);
assert!(
s[1] > -10.0 && s[1] < 15.0,
"Sample dimension 1: {} outside expected range",
s[1]
);
}
}
#[test]
fn test_multivariate_kde_sample_single_sample() {
let samples = vec![vec![5.0, 10.0]];
let kde = MultivariateKDE::new(samples).unwrap();
let mut rng = fastrand::Rng::new();
let n_samples = 100;
let mut sum_x = 0.0;
let mut sum_y = 0.0;
for _ in 0..n_samples {
let s = kde.sample(&mut rng);
sum_x += s[0];
sum_y += s[1];
}
let mean_x = sum_x / f64::from(n_samples);
let mean_y = sum_y / f64::from(n_samples);
assert!(
(mean_x - 5.0).abs() < 1.0,
"Mean x={mean_x}, expected close to 5.0"
);
assert!(
(mean_y - 10.0).abs() < 1.0,
"Mean y={mean_y}, expected close to 10.0"
);
}
#[test]
fn test_multivariate_kde_sample_high_dimensional() {
let samples: Vec<Vec<f64>> = (0..10)
.map(|i| {
let x = f64::from(i) * 0.5;
                vec![x, x * 2.0, x * 0.5, x + 1.0, x - 1.0]
            })
.collect();
let kde = MultivariateKDE::new(samples).unwrap();
let mut rng = fastrand::Rng::new();
for _ in 0..50 {
let sample = kde.sample(&mut rng);
assert_eq!(sample.len(), 5);
for &val in &sample {
assert!(val.is_finite(), "Sample value is not finite: {val}");
}
}
}
#[test]
#[allow(clippy::cast_precision_loss)]
fn test_multivariate_kde_sample_respects_bandwidth() {
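        // All kernel centers coincide, so the sampling variance in each dimension
        // should be roughly the squared bandwidth (about 0.01 and 100.0 here).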
let data = vec![vec![0.0, 0.0], vec![0.0, 0.0], vec![0.0, 0.0]];
        let bandwidths = vec![0.1, 10.0];
        let kde = MultivariateKDE::with_bandwidths(data, bandwidths).unwrap();
let mut rng = fastrand::Rng::new();
let n_samples = 1000;
let mut values_x: Vec<f64> = Vec::with_capacity(n_samples);
let mut values_y: Vec<f64> = Vec::with_capacity(n_samples);
for _ in 0..n_samples {
let s = kde.sample(&mut rng);
values_x.push(s[0]);
values_y.push(s[1]);
}
let n = n_samples as f64;
let mean_x: f64 = values_x.iter().sum::<f64>() / n;
let mean_y: f64 = values_y.iter().sum::<f64>() / n;
let var_x: f64 = values_x.iter().map(|x| (x - mean_x).powi(2)).sum::<f64>() / n;
let var_y: f64 = values_y.iter().map(|y| (y - mean_y).powi(2)).sum::<f64>() / n;
assert!(
var_x < 0.05,
"X variance {var_x} too large for bandwidth 0.1"
);
assert!(
var_y > 50.0 && var_y < 200.0,
"Y variance {var_y} unexpected for bandwidth 10.0"
);
}
#[test]
fn test_multivariate_kde_sample_distribution_shape() {
let data = vec![
vec![0.0, 0.0],
vec![1.0, 1.0],
vec![2.0, 2.0],
vec![3.0, 3.0],
vec![4.0, 4.0],
];
let kde = MultivariateKDE::new(data).unwrap();
let mut rng = fastrand::Rng::new();
let n_samples = 500;
let mut sum = [0.0, 0.0];
for _ in 0..n_samples {
let s = kde.sample(&mut rng);
sum[0] += s[0];
sum[1] += s[1];
}
let mean_x = sum[0] / f64::from(n_samples);
let mean_y = sum[1] / f64::from(n_samples);
assert!(
(mean_x - 2.0).abs() < 0.5,
"Mean x={mean_x}, expected close to 2.0"
);
assert!(
(mean_y - 2.0).abs() < 0.5,
"Mean y={mean_y}, expected close to 2.0"
);
}
#[test]
fn test_multivariate_kde_sample_deterministic_with_seeded_rng() {
let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
let kde = MultivariateKDE::new(data).unwrap();
let mut rng1 = fastrand::Rng::with_seed(42);
let mut rng2 = fastrand::Rng::with_seed(42);
let result1 = kde.sample(&mut rng1);
let result2 = kde.sample(&mut rng2);
assert!(
(result1[0] - result2[0]).abs() < f64::EPSILON,
"Samples with same seed differ in dimension 0"
);
assert!(
(result1[1] - result2[1]).abs() < f64::EPSILON,
"Samples with same seed differ in dimension 1"
);
}
}