use nalgebra::RealField;
use num_traits::Float;
#[cfg(feature = "alloc")]
use alloc::vec::Vec;
use crate::types::gaussian::{GaussianMixture, GaussianState};
use crate::types::spaces::{StateCovariance, StateVector};
/// Parameters controlling Gaussian-mixture reduction in `prune_and_merge`.
#[derive(Debug, Clone)]
pub struct PruningConfig<T: RealField> {
    /// Components whose weight is strictly below this value are pruned.
    pub weight_threshold: T,
    /// Merge gate, expressed as a (non-squared) Mahalanobis distance: it is
    /// squared before being compared with squared distances.
    pub merge_threshold: T,
    /// Maximum number of components kept after reduction.
    pub max_components: usize,
}

impl<T: RealField + Float> PruningConfig<T> {
    /// Returns the default reduction settings:
    /// `weight_threshold = 1e-5`, `merge_threshold = 4.0`, `max_components = 100`.
    pub fn default_config() -> Self {
        Self {
            weight_threshold: T::from_f64(1e-5)
                .expect("1e-5 must be representable in the scalar type"),
            merge_threshold: T::from_f64(4.0)
                .expect("4.0 must be representable in the scalar type"),
            max_components: 100,
        }
    }

    /// Builds a configuration from explicit values (no validation is performed).
    pub fn new(weight_threshold: T, merge_threshold: T, max_components: usize) -> Self {
        Self {
            weight_threshold,
            merge_threshold,
            max_components,
        }
    }
}

/// Same values as [`PruningConfig::default_config`].
impl<T: RealField + Float> Default for PruningConfig<T> {
    fn default() -> Self {
        Self::default_config()
    }
}
/// Drops every component whose weight is below `threshold`.
///
/// The combined weight of the dropped components (if positive) is shared out in
/// equal parts among the survivors, so the total weight is preserved. If nothing
/// survives, the result is an empty mixture and the dropped weight is lost.
/// A NaN weight satisfies neither `< threshold` nor `>= threshold`, so such a
/// component is discarded without adding to the redistributed amount.
#[cfg(feature = "alloc")]
pub fn prune_by_weight<T: RealField + Float + Copy, const N: usize>(
    mixture: &GaussianMixture<T, N>,
    threshold: T,
) -> GaussianMixture<T, N> {
    let mut kept = Vec::new();
    let mut discarded_weight = T::zero();
    for component in mixture.iter() {
        if component.weight < threshold {
            discarded_weight += component.weight;
        } else if component.weight >= threshold {
            kept.push(component.clone());
        }
    }

    let should_redistribute = !kept.is_empty() && discarded_weight > T::zero();
    if should_redistribute {
        let share = discarded_weight / T::from(kept.len()).unwrap();
        kept.iter_mut().for_each(|survivor| survivor.weight += share);
    }

    GaussianMixture::from_components(kept)
}
/// Keeps at most `max_components` of the heaviest components of `mixture`.
///
/// Components are ranked by descending weight. The sort is stable, so ties keep
/// their original relative order. The summed weight of the discarded tail (if
/// positive) is redistributed equally over the survivors, which preserves the
/// total weight. A mixture that already has at most `max_components` entries is
/// returned as an unchanged clone. `max_components == 0` yields an empty mixture.
///
/// NaN weights rank below every real weight, so such components are discarded
/// first.
#[cfg(feature = "alloc")]
pub fn truncate<T: RealField + Float + Copy, const N: usize>(
    mixture: &GaussianMixture<T, N>,
    max_components: usize,
) -> GaussianMixture<T, N> {
    if mixture.len() <= max_components {
        return mixture.clone();
    }

    let mut ranked: Vec<_> = mixture.iter().collect();
    // `sort_by` requires a total order (newer std versions may panic otherwise),
    // so NaN must not be reported as "equal" to everything: order NaNs after
    // all real weights and treat NaNs as equal to each other.
    ranked.sort_by(|a, b| match b.weight.partial_cmp(&a.weight) {
        Some(order) => order,
        None => Float::is_nan(a.weight).cmp(&Float::is_nan(b.weight)),
    });

    let truncated_weight_sum: T = ranked
        .iter()
        .skip(max_components)
        .fold(T::zero(), |acc, c| acc + c.weight);

    let mut remaining: Vec<_> = ranked
        .into_iter()
        .take(max_components)
        .cloned()
        .collect();

    if !remaining.is_empty() && truncated_weight_sum > T::zero() {
        let n_remaining = T::from(remaining.len()).unwrap();
        let weight_per_component = truncated_weight_sum / n_remaining;
        for component in &mut remaining {
            component.weight += weight_per_component;
        }
    }
    GaussianMixture::from_components(remaining)
}
/// Squared Mahalanobis distance between the means of `a` and `b`, using `a`'s covariance.
///
/// Computes `(m_a - m_b)^T * P_a^{-1} * (m_a - m_b)`, where `P_a` is `a.covariance`.
/// The measure is asymmetric because only `a`'s covariance is used. When `P_a`
/// cannot be inverted (`try_inverse` yields `None`), infinity is returned, which
/// compares greater than any finite gate.
#[cfg(feature = "alloc")]
pub fn mahalanobis_distance_squared<T: RealField + Float + Copy, const N: usize>(
    a: &GaussianState<T, N>,
    b: &GaussianState<T, N>,
) -> T {
    let offset = StateVector::from_svector(a.mean.as_svector() - b.mean.as_svector());
    match a.covariance.try_inverse() {
        Some(precision) => {
            let delta = offset.as_svector();
            let quadratic_form = delta.transpose() * precision.as_matrix() * delta;
            quadratic_form[(0, 0)]
        }
        None => T::infinity(),
    }
}
/// Moment-matches two weighted Gaussian components into a single component.
///
/// With `w = w_a + w_b`, the result has:
/// - weight `w`;
/// - mean `m = (w_a * m_a + w_b * m_b) / w`;
/// - covariance `P = (w_a * P_a + w_b * P_b + w_a * d_a d_a^T + w_b * d_b d_b^T) / w`,
///   where `d_i = m_i - m`.
///
/// Returns `None` when `w <= 0`. A NaN `w` is not rejected.
#[cfg(feature = "alloc")]
pub fn merge_components<T: RealField + Float + Copy, const N: usize>(
    a: &GaussianState<T, N>,
    b: &GaussianState<T, N>,
) -> Option<GaussianState<T, N>> {
    let total_weight = a.weight + b.weight;
    if total_weight <= T::zero() {
        return None;
    }
    let inv_total = T::one() / total_weight;

    let weighted_means =
        a.mean.as_svector().scale(a.weight) + b.mean.as_svector().scale(b.weight);
    let mean = StateVector::from_svector(weighted_means.scale(inv_total));

    let dev_a = StateVector::from_svector(a.mean.as_svector() - mean.as_svector());
    let dev_b = StateVector::from_svector(b.mean.as_svector() - mean.as_svector());
    let outer_a = dev_a.as_svector() * dev_a.as_svector().transpose();
    let outer_b = dev_b.as_svector() * dev_b.as_svector().transpose();

    // Summed strictly left to right, matching ((Pa + Pb) + Sa) + Sb.
    let own_spread =
        a.covariance.as_matrix().scale(a.weight) + b.covariance.as_matrix().scale(b.weight);
    let with_a = own_spread + outer_a.scale(a.weight);
    let with_both = with_a + outer_b.scale(b.weight);
    let covariance = StateCovariance::from_matrix(with_both.scale(inv_total));

    Some(GaussianState::new(total_weight, mean, covariance))
}
/// Greedily merges components that lie close to a heavier component.
///
/// Each round proceeds as follows:
/// 1. Take the heaviest remaining component as an anchor.
/// 2. Collect every other remaining component whose squared Mahalanobis distance
///    from the anchor is strictly below `threshold * threshold`. The distance is
///    measured with the anchor's covariance, see `mahalanobis_distance_squared`.
/// 3. Moment-match the anchor with that whole group via `merge_components` and
///    emit one output component.
///
/// `threshold` is therefore a distance, not a squared distance.
///
/// Every candidate is gated against the *original* anchor, not against the
/// partially merged result. This makes the outcome independent of the order in
/// which candidates are visited. It also prevents a growing merged covariance
/// from chaining in progressively farther components.
///
/// If a pairwise merge is rejected (combined weight not positive), that candidate
/// is dropped and the accumulated component is kept unchanged.
#[cfg(feature = "alloc")]
pub fn merge_nearby<T: RealField + Float + Copy, const N: usize>(
    mixture: &GaussianMixture<T, N>,
    threshold: T,
) -> GaussianMixture<T, N> {
    let threshold_sq = threshold * threshold;
    let mut remaining: Vec<_> = mixture.iter().cloned().collect();
    let mut merged = GaussianMixture::new();

    while let Some(anchor_idx) = remaining
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| {
            a.weight
                .partial_cmp(&b.weight)
                .unwrap_or(core::cmp::Ordering::Equal)
        })
        .map(|(i, _)| i)
    {
        // `remove` (rather than `swap_remove`) keeps the relative order of the
        // other components, so tie-breaking and merge order stay deterministic.
        let anchor = remaining.remove(anchor_idx);

        let (close, far): (Vec<_>, Vec<_>) = remaining.into_iter().partition(|candidate| {
            mahalanobis_distance_squared(&anchor, candidate) < threshold_sq
        });
        remaining = far;

        let fused = close.iter().fold(anchor, |acc, candidate| {
            merge_components(&acc, candidate).unwrap_or(acc)
        });
        merged.push(fused);
    }

    merged
}
/// Runs the full reduction pipeline configured by `config`.
///
/// The stages run in this order:
/// 1. Prune components below `weight_threshold`.
/// 2. Merge components within `merge_threshold` of a heavier one.
/// 3. Keep at most `max_components` of the heaviest results.
#[cfg(feature = "alloc")]
pub fn prune_and_merge<T: RealField + Float + Copy, const N: usize>(
    mixture: &GaussianMixture<T, N>,
    config: &PruningConfig<T>,
) -> GaussianMixture<T, N> {
    let PruningConfig {
        weight_threshold,
        merge_threshold,
        max_components,
    } = config;
    truncate(
        &merge_nearby(&prune_by_weight(mixture, *weight_threshold), *merge_threshold),
        *max_components,
    )
}
/// Rescales every component weight so the mixture's total weight equals `target_sum`.
///
/// The mixture is left untouched when its current total weight is not strictly
/// positive (zero, negative, or NaN).
#[cfg(feature = "alloc")]
pub fn normalize_weights<T: RealField + Float + Copy, const N: usize>(
    mixture: &mut GaussianMixture<T, N>,
    target_sum: T,
) {
    let current_total = mixture.total_weight();
    if let Some(core::cmp::Ordering::Greater) = current_total.partial_cmp(&T::zero()) {
        mixture.scale_weights(target_sum / current_total);
    }
}
/// Scales all weights by a common factor so the total weight becomes `max_weight`
/// whenever the current total exceeds it.
///
/// Mixtures whose total is at or below `max_weight` (or NaN) are left untouched.
/// NOTE(review): `max_weight` is presumably non-negative. A negative cap would
/// produce a negative scale factor, so confirm this with callers.
#[cfg(feature = "alloc")]
pub fn cap_total_weight<T: RealField + Float + Copy, const N: usize>(
    mixture: &mut GaussianMixture<T, N>,
    max_weight: T,
) {
    let current_total = mixture.total_weight();
    let exceeds_cap = current_total > max_weight;
    if exceeds_cap {
        let factor = max_weight / current_total;
        mixture.scale_weights(factor);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 4-D component at `(x, y, 0, 0)` with identity covariance.
    #[cfg(feature = "alloc")]
    fn make_component(weight: f64, x: f64, y: f64) -> GaussianState<f64, 4> {
        GaussianState::new(
            weight,
            StateVector::from_array([x, y, 0.0, 0.0]),
            StateCovariance::identity(),
        )
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_prune_by_weight() {
        let mut mixture = GaussianMixture::new();
        mixture.push(make_component(0.5, 0.0, 0.0));
        mixture.push(make_component(0.001, 10.0, 10.0));
        mixture.push(make_component(0.3, 20.0, 20.0));
        let original_total = mixture.total_weight();
        let pruned = prune_by_weight(&mixture, 0.01);
        assert_eq!(pruned.len(), 2);
        assert!((pruned.total_weight() - original_total).abs() < 1e-10);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_truncate() {
        let mut mixture = GaussianMixture::new();
        for i in 0..10 {
            mixture.push(make_component(i as f64 * 0.1, i as f64, 0.0));
        }
        let original_total = mixture.total_weight();
        let truncated = truncate(&mixture, 3);
        assert_eq!(truncated.len(), 3);
        assert!((truncated.total_weight() - original_total).abs() < 1e-10);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_truncate_keeps_heaviest() {
        let mut mixture = GaussianMixture::new();
        for i in 0..10 {
            mixture.push(make_component(i as f64 * 0.1, i as f64, 0.0));
        }
        let truncated = truncate(&mixture, 3);
        // Survivors are the 0.9 / 0.8 / 0.7 components, each topped up by
        // (0.0 + 0.1 + ... + 0.6) / 3 = 0.7; a wrongly kept 0.6 would end at 1.3.
        assert!(truncated.iter().all(|c| c.weight > 1.35));
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_merge_components() {
        let a = make_component(0.6, 0.0, 0.0);
        let b = make_component(0.4, 2.0, 0.0);
        let merged = merge_components(&a, &b).expect("merge should succeed with positive weights");
        assert!((merged.weight - 1.0).abs() < 1e-10);
        assert!((merged.mean.index(0) - 0.8).abs() < 1e-10);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_merge_components_zero_weights() {
        let a = make_component(0.0, 0.0, 0.0);
        let b = make_component(0.0, 2.0, 0.0);
        assert!(merge_components(&a, &b).is_none());
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_merge_nearby() {
        let mut mixture = GaussianMixture::new();
        mixture.push(make_component(0.5, 0.0, 0.0));
        mixture.push(make_component(0.3, 0.5, 0.0));
        mixture.push(make_component(0.2, 100.0, 100.0));
        let merged = merge_nearby(&mixture, 2.0);
        assert_eq!(merged.len(), 2);
        assert!((merged.total_weight() - 1.0).abs() < 1e-10);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_prune_and_merge() {
        let mut mixture = GaussianMixture::new();
        mixture.push(make_component(0.5, 0.0, 0.0));
        mixture.push(make_component(0.3, 0.5, 0.0));
        mixture.push(make_component(0.0001, 50.0, 50.0));
        mixture.push(make_component(0.2, 100.0, 100.0));
        let config = PruningConfig::new(0.001, 2.0, 10);
        let result = prune_and_merge(&mixture, &config);
        assert_eq!(result.len(), 2);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_normalize_weights() {
        let mut mixture = GaussianMixture::new();
        mixture.push(make_component(0.5, 0.0, 0.0));
        mixture.push(make_component(1.5, 10.0, 0.0));
        normalize_weights(&mut mixture, 1.0);
        assert!((mixture.total_weight() - 1.0).abs() < 1e-10);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_normalize_weights_zero_total_is_noop() {
        let mut mixture = GaussianMixture::new();
        mixture.push(make_component(0.0, 0.0, 0.0));
        normalize_weights(&mut mixture, 1.0);
        assert!(mixture.total_weight().abs() < 1e-12);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_cap_total_weight() {
        let mut mixture = GaussianMixture::new();
        mixture.push(make_component(2.0, 0.0, 0.0));
        mixture.push(make_component(1.0, 10.0, 0.0));
        cap_total_weight(&mut mixture, 1.5);
        assert!((mixture.total_weight() - 1.5).abs() < 1e-10);
        // Already under the cap: must be left unchanged.
        cap_total_weight(&mut mixture, 5.0);
        assert!((mixture.total_weight() - 1.5).abs() < 1e-10);
    }
}