use nalgebra::RealField;
use num_traits::Float;
#[cfg(feature = "alloc")]
use alloc::vec::Vec;
use crate::types::gaussian::{GaussianMixture, GaussianState};
use crate::types::spaces::{StateCovariance, StateVector};
/// A single target estimate extracted from a Gaussian mixture.
///
/// Bundles a state vector with its covariance and a scalar confidence value.
#[derive(Debug, Clone)]
pub struct TargetEstimate<T: RealField, const N: usize> {
/// Estimated state of the target.
pub state: StateVector<T, N>,
/// Covariance associated with `state`.
pub covariance: StateCovariance<T, N>,
/// Confidence attached to the estimate; `from_gaussian` fills it with the
/// weight of the source component.
pub confidence: T,
}
impl<T: RealField + Copy, const N: usize> TargetEstimate<T, N> {
    /// Assembles an estimate from its three parts, stored exactly as given.
    pub fn new(
        state: StateVector<T, N>,
        covariance: StateCovariance<T, N>,
        confidence: T,
    ) -> Self {
        Self {
            state,
            covariance,
            confidence,
        }
    }

    /// Turns a mixture component into an estimate.
    ///
    /// The component's mean becomes the state, its covariance is carried over
    /// unchanged, and its weight is used as the confidence.
    pub fn from_gaussian(gaussian: &GaussianState<T, N>) -> Self {
        Self::new(gaussian.mean, gaussian.covariance, gaussian.weight)
    }
}
/// Strategy used by `extract_targets` to turn mixture components into
/// target estimates.
#[derive(Debug, Clone, Copy)]
pub enum ExtractionMethod {
/// Keep every component whose weight is at least
/// `ExtractionConfig::weight_threshold`.
WeightThreshold,
/// Keep the `ExtractionConfig::top_n` components with the largest weights.
TopN,
/// Keep the highest-weighted components, as many as the rounded total
/// weight of the mixture.
ExpectedCount,
/// Visit components in descending weight order (ignoring those below
/// `ExtractionConfig::weight_threshold`) and keep one only if its squared
/// Mahalanobis distance to every already kept estimate exceeds
/// `ExtractionConfig::local_maxima_dist_sq`.
LocalMaxima,
}
/// Parameters for `extract_targets`.
///
/// Only the fields relevant to the selected `method` are read; the others are
/// ignored for that method.
#[derive(Debug, Clone)]
pub struct ExtractionConfig<T: RealField> {
/// Which extraction strategy to run.
pub method: ExtractionMethod,
/// Minimum component weight; read by the `WeightThreshold` and
/// `LocalMaxima` methods.
pub weight_threshold: T,
/// Number of components to keep; read by the `TopN` method.
pub top_n: usize,
/// Squared Mahalanobis distance a component must exceed, relative to every
/// already kept estimate, to be kept by the `LocalMaxima` method.
pub local_maxima_dist_sq: T,
}
impl<T: RealField + Float> ExtractionConfig<T> {
    /// Squared merge distance (`4.0`) stored by every preset that does not
    /// take one explicitly.
    ///
    /// Panics if `4.0` cannot be converted into `T`.
    fn default_local_maxima_dist_sq() -> T {
        T::from_f64(4.0).unwrap()
    }

    /// Preset for [`ExtractionMethod::WeightThreshold`] with the given
    /// minimum weight.
    ///
    /// # Panics
    ///
    /// Panics if the default merge distance `4.0` cannot be converted into `T`.
    pub fn weight_threshold(threshold: T) -> Self {
        let local_maxima_dist_sq = Self::default_local_maxima_dist_sq();
        Self {
            method: ExtractionMethod::WeightThreshold,
            weight_threshold: threshold,
            top_n: 0,
            local_maxima_dist_sq,
        }
    }

    /// Preset for [`ExtractionMethod::TopN`] keeping `n` components.
    ///
    /// # Panics
    ///
    /// Panics if the default merge distance `4.0` cannot be converted into `T`.
    pub fn top_n(n: usize) -> Self {
        let local_maxima_dist_sq = Self::default_local_maxima_dist_sq();
        Self {
            method: ExtractionMethod::TopN,
            weight_threshold: T::zero(),
            top_n: n,
            local_maxima_dist_sq,
        }
    }

    /// Preset for [`ExtractionMethod::ExpectedCount`].
    ///
    /// The stored weight threshold is `0.5`.
    ///
    /// # Panics
    ///
    /// Panics if `0.5` or `4.0` cannot be converted into `T`.
    pub fn expected_count() -> Self {
        let weight_threshold = T::from_f64(0.5).unwrap();
        let local_maxima_dist_sq = Self::default_local_maxima_dist_sq();
        Self {
            method: ExtractionMethod::ExpectedCount,
            weight_threshold,
            top_n: 0,
            local_maxima_dist_sq,
        }
    }

    /// Preset for [`ExtractionMethod::LocalMaxima`] with an explicit minimum
    /// weight and squared merge distance.
    pub fn local_maxima(min_weight: T, merge_dist_sq: T) -> Self {
        Self {
            method: ExtractionMethod::LocalMaxima,
            weight_threshold: min_weight,
            top_n: 0,
            local_maxima_dist_sq: merge_dist_sq,
        }
    }
}
/// Extracts target estimates from `mixture` using the strategy selected in
/// `config`.
///
/// Dispatches to [`extract_by_threshold`], [`extract_top_n`] or
/// [`extract_local_maxima_with_dist`]. For
/// [`ExtractionMethod::ExpectedCount`] the number of targets is the rounded
/// total weight of the mixture, as computed by [`estimate_cardinality`].
#[cfg(feature = "alloc")]
pub fn extract_targets<T: RealField + Float + Copy, const N: usize>(
    mixture: &GaussianMixture<T, N>,
    config: &ExtractionConfig<T>,
) -> Vec<TargetEstimate<T, N>> {
    match config.method {
        ExtractionMethod::LocalMaxima => extract_local_maxima_with_dist(
            mixture,
            config.weight_threshold,
            config.local_maxima_dist_sq,
        ),
        ExtractionMethod::ExpectedCount => {
            let expected_targets = estimate_cardinality(mixture);
            extract_top_n(mixture, expected_targets)
        }
        ExtractionMethod::TopN => extract_top_n(mixture, config.top_n),
        ExtractionMethod::WeightThreshold => {
            extract_by_threshold(mixture, config.weight_threshold)
        }
    }
}
/// Returns an estimate for every component whose weight is at least
/// `threshold`, in mixture order.
#[cfg(feature = "alloc")]
pub fn extract_by_threshold<T: RealField + Float + Copy, const N: usize>(
    mixture: &GaussianMixture<T, N>,
    threshold: T,
) -> Vec<TargetEstimate<T, N>> {
    let mut targets = Vec::new();
    for component in mixture.iter() {
        if component.weight >= threshold {
            targets.push(TargetEstimate::from_gaussian(component));
        }
    }
    targets
}
/// Returns estimates for the `n` highest-weighted components, ordered by
/// descending weight.
///
/// Fewer than `n` estimates are returned when the mixture has fewer
/// components, and an empty vector when `n` is zero. Components with equal
/// weights keep their relative mixture order (the sort is stable).
///
/// Components whose weight is NaN are ranked after all others, so they are
/// selected only once every comparable component has been taken.
#[cfg(feature = "alloc")]
pub fn extract_top_n<T: RealField + Float + Copy, const N: usize>(
    mixture: &GaussianMixture<T, N>,
    n: usize,
) -> Vec<TargetEstimate<T, N>> {
    use core::cmp::Ordering;

    if n == 0 {
        return Vec::new();
    }

    let mut ranked: Vec<_> = mixture.iter().collect();
    // `sort_by` requires a total order. Mapping incomparable (NaN) weights to
    // `Equal` breaks transitivity, which makes the resulting order
    // unspecified and may panic, so NaN weights are ranked last explicitly.
    ranked.sort_by(|a, b| {
        let a_is_nan = num_traits::Float::is_nan(a.weight);
        let b_is_nan = num_traits::Float::is_nan(b.weight);
        match (a_is_nan, b_is_nan) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            // Neither is NaN: plain descending comparison.
            (false, false) => b
                .weight
                .partial_cmp(&a.weight)
                .unwrap_or(Ordering::Equal),
        }
    });

    ranked
        .into_iter()
        .take(n)
        .map(TargetEstimate::from_gaussian)
        .collect()
}
/// Local-maxima extraction with the default squared merge distance of `4.0`.
///
/// Equivalent to calling [`extract_local_maxima_with_dist`] with
/// `merge_dist_sq = 4.0`.
///
/// # Panics
///
/// Panics if `4.0` cannot be converted into `T`.
#[cfg(feature = "alloc")]
pub fn extract_local_maxima<T: RealField + Float + Copy, const N: usize>(
    mixture: &GaussianMixture<T, N>,
    min_weight: T,
) -> Vec<TargetEstimate<T, N>> {
    let default_merge_dist_sq = T::from_f64(4.0).unwrap();
    extract_local_maxima_with_dist(mixture, min_weight, default_merge_dist_sq)
}
/// Greedy local-maxima extraction.
///
/// Components with weight below `min_weight` are discarded. The remaining
/// ones are visited in descending weight order, and a component is kept only
/// if its squared Mahalanobis distance (per
/// `crate::utils::pruning::mahalanobis_distance_squared`) to every estimate
/// kept so far is strictly greater than `merge_dist_sq`.
///
/// The returned estimates are in the order they were accepted, i.e. by
/// descending weight.
#[cfg(feature = "alloc")]
pub fn extract_local_maxima_with_dist<T: RealField + Float + Copy, const N: usize>(
    mixture: &GaussianMixture<T, N>,
    min_weight: T,
    merge_dist_sq: T,
) -> Vec<TargetEstimate<T, N>> {
    use crate::utils::pruning::mahalanobis_distance_squared;

    // True when `candidate` lies farther than the merge distance from `kept`.
    // The kept estimate is wrapped in a unit-weight Gaussian for the distance
    // call (presumably the weight does not enter the distance — confirm
    // against `mahalanobis_distance_squared`).
    let is_separated = |candidate: &GaussianState<T, N>, kept: &TargetEstimate<T, N>| {
        let kept_gaussian = GaussianState::new(T::one(), kept.state, kept.covariance);
        mahalanobis_distance_squared(candidate, &kept_gaussian) > merge_dist_sq
    };

    // Candidates above the weight floor, strongest first.
    let mut candidates: Vec<_> = mixture
        .iter()
        .filter(|c| c.weight >= min_weight)
        .cloned()
        .collect();
    candidates.sort_by(|lhs, rhs| {
        rhs.weight
            .partial_cmp(&lhs.weight)
            .unwrap_or(core::cmp::Ordering::Equal)
    });

    let mut accepted: Vec<TargetEstimate<T, N>> = Vec::new();
    for candidate in candidates {
        if accepted.iter().all(|kept| is_separated(&candidate, kept)) {
            accepted.push(TargetEstimate::from_gaussian(&candidate));
        }
    }
    accepted
}
/// Estimates the number of targets as the total mixture weight rounded to
/// the nearest integer.
///
/// Returns `0` when the rounded value cannot be represented as a `usize`.
#[cfg(feature = "alloc")]
pub fn estimate_cardinality<T: RealField + Float + Copy, const N: usize>(
    mixture: &GaussianMixture<T, N>,
) -> usize {
    let rounded_weight = num_traits::Float::round(mixture.total_weight());
    rounded_weight.to_usize().unwrap_or_default()
}
/// Weight-averaged mean of all mixture components.
///
/// Computes `sum(w_i * m_i) / sum(w_i)`. Returns `None` when the mixture is
/// empty or its total weight is zero or negative.
#[cfg(feature = "alloc")]
pub fn mixture_mean<T: RealField + Float + Copy, const N: usize>(
    mixture: &GaussianMixture<T, N>,
) -> Option<StateVector<T, N>> {
    if mixture.is_empty() {
        return None;
    }

    let total_weight = mixture.total_weight();
    if total_weight <= T::zero() {
        return None;
    }
    let inverse_total = T::one() / total_weight;

    let weighted_sum = mixture.iter().fold(
        nalgebra::SVector::<T, N>::zeros(),
        |mut acc, component| {
            acc += component.mean.as_svector().scale(component.weight);
            acc
        },
    );

    Some(StateVector::from_svector(weighted_sum.scale(inverse_total)))
}
/// Overall covariance of the mixture around its weighted mean.
///
/// Computes `sum(w_i * (P_i + d_i * d_i^T)) / sum(w_i)` where
/// `d_i = m_i - mean` and `mean` comes from [`mixture_mean`]. Returns `None`
/// whenever [`mixture_mean`] does, or when the total weight is zero or
/// negative.
#[cfg(feature = "alloc")]
pub fn mixture_covariance<T: RealField + Float + Copy, const N: usize>(
    mixture: &GaussianMixture<T, N>,
) -> Option<StateCovariance<T, N>> {
    let mean = mixture_mean(mixture)?;

    let total_weight = mixture.total_weight();
    if total_weight <= T::zero() {
        return None;
    }
    let inverse_total = T::one() / total_weight;

    let weighted_sum = mixture.iter().fold(
        nalgebra::SMatrix::<T, N, N>::zeros(),
        |mut acc, component| {
            // Spread of this component's mean around the mixture mean.
            let diff = component.mean.as_svector() - mean.as_svector();
            let spread = diff * diff.transpose();
            acc += (component.covariance.as_matrix() + spread).scale(component.weight);
            acc
        },
    );

    Some(StateCovariance::from_matrix(
        weighted_sum.scale(inverse_total),
    ))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons.
    #[cfg(feature = "alloc")]
    const TOLERANCE: f64 = 1e-10;

    /// Unit-covariance 4-D component located at `(x, y)` with zero for the
    /// remaining two state entries.
    #[cfg(feature = "alloc")]
    fn make_gaussian(weight: f64, x: f64, y: f64) -> GaussianState<f64, 4> {
        let mean = StateVector::from_array([x, y, 0.0, 0.0]);
        GaussianState::new(weight, mean, StateCovariance::identity())
    }

    /// Builds a mixture from `(weight, x, y)` triples, pushed in order.
    #[cfg(feature = "alloc")]
    fn mixture_of(components: &[(f64, f64, f64)]) -> GaussianMixture<f64, 4> {
        let mut mixture = GaussianMixture::new();
        for &(weight, x, y) in components {
            mixture.push(make_gaussian(weight, x, y));
        }
        mixture
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_extract_by_threshold() {
        let mixture = mixture_of(&[(0.8, 0.0, 0.0), (0.3, 10.0, 10.0), (0.1, 20.0, 20.0)]);

        let targets = extract_by_threshold(&mixture, 0.5);

        // Only the 0.8-weight component at the origin clears the threshold.
        assert_eq!(targets.len(), 1);
        assert!((targets[0].state.index(0) - 0.0).abs() < TOLERANCE);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_extract_top_n() {
        let mixture = mixture_of(&[(0.3, 0.0, 0.0), (0.8, 10.0, 10.0), (0.5, 20.0, 20.0)]);

        let targets = extract_top_n(&mixture, 2);

        // The two heaviest components come back, heaviest first.
        assert_eq!(targets.len(), 2);
        assert!((targets[0].confidence - 0.8).abs() < TOLERANCE);
        assert!((targets[1].confidence - 0.5).abs() < TOLERANCE);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_estimate_cardinality() {
        let mixture = mixture_of(&[(0.8, 0.0, 0.0), (0.7, 10.0, 10.0), (0.6, 20.0, 20.0)]);

        // Total weight 2.1 rounds to 2.
        assert_eq!(estimate_cardinality(&mixture), 2);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_mixture_mean() {
        let mixture = mixture_of(&[(0.5, 0.0, 0.0), (0.5, 10.0, 0.0)]);

        let mean = mixture_mean(&mixture).unwrap();

        // Equal weights put the mean halfway between the two components.
        assert!((mean.index(0) - 5.0).abs() < TOLERANCE);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_extract_config() {
        let mixture = mixture_of(&[(0.9, 0.0, 0.0), (0.1, 10.0, 10.0)]);
        let config = ExtractionConfig::expected_count();

        let targets = extract_targets(&mixture, &config);

        // Total weight 1.0 means exactly one expected target.
        assert_eq!(targets.len(), 1);
    }
}