use std::collections::HashSet;
use crate::math::{DistanceMetric, FloatNumber, Point};
/// Strategy used by `SamplingStrategy::sample` to pick representative point indices.
///
/// The default is [`SamplingStrategy::FarthestPointSampling`].
#[derive(Debug, Clone, Default, PartialEq)]
pub enum SamplingStrategy<T>
where
    T: FloatNumber,
{
    /// Plain farthest-point sampling: starts at index 0 and repeatedly adds the point whose
    /// squared Euclidean distance to its nearest already-selected point is largest.
    #[default]
    FarthestPointSampling,
    /// Weighted farthest-point sampling: starts at the point with the largest weight and
    /// multiplies each candidate's distance by that candidate's own weight.
    ///
    /// Holds one weight per point, in the same order as the points passed to `sample`.
    WeightedFarthestPointSampling(Vec<T>),
}
impl<T> SamplingStrategy<T>
where
    T: FloatNumber,
{
    /// Samples up to `n` point indices according to this strategy.
    ///
    /// * Returns an empty set when `n` is zero or `points` is empty.
    /// * Returns every index when `points` holds no more than `n` points.
    /// * Otherwise runs greedy farthest-point selection with squared Euclidean distance.
    ///   The weighted variant starts from the index with the largest weight (the last one
    ///   on ties) and multiplies each candidate's distance by that candidate's weight.
    ///
    /// # Panics
    /// * In debug builds, if the weighted variant's weight count differs from `points.len()`;
    ///   release builds do not check this up front.
    /// * If two weights cannot be ordered while picking the start index (e.g. a NaN weight).
    pub fn sample<const N: usize>(&self, points: &[Point<T, N>], n: usize) -> HashSet<usize> {
        let count = points.len();
        let nothing_requested = n == 0 || count == 0;
        if nothing_requested {
            return HashSet::new();
        }
        // Asking for at least as many samples as there are points selects all of them.
        if n >= count {
            return (0..count).collect();
        }

        let metric = DistanceMetric::SquaredEuclidean;
        match self {
            Self::FarthestPointSampling => {
                sample_with_distance_fn(points, n, 0, |_, lhs, rhs| metric.measure(lhs, rhs))
            }
            Self::WeightedFarthestPointSampling(weights) => {
                debug_assert_eq!(
                    count,
                    weights.len(),
                    "The number of points and weights must be equal."
                );
                // Begin from the heaviest point so the most important sample is always kept.
                let heaviest_index = weights
                    .iter()
                    .enumerate()
                    .max_by(|lhs, rhs| lhs.1.partial_cmp(rhs.1).unwrap())
                    .map(|(index, _)| index)
                    .unwrap();
                sample_with_distance_fn(points, n, heaviest_index, |index, lhs, rhs| {
                    metric.measure(lhs, rhs) * weights[index]
                })
            }
        }
    }
}
/// Greedy farthest-point selection driven by a caller-supplied distance function.
///
/// Starting from `initial_index`, repeatedly adds the not-yet-selected point whose distance
/// to its nearest selected point (as measured by `distance_fn`) is largest. It stops once
/// `n` indices are chosen or no remaining point has a strictly positive distance.
///
/// # Arguments
/// * `points` - The candidate points.
/// * `n` - The target number of indices to select.
/// * `initial_index` - Index of the first selected point.
/// * `distance_fn` - Called as `distance_fn(index, point, selected_point)`. Here `index` is
///   the position of `point` in `points`, which lets callers weight distances per point.
///
/// # Returns
/// The selected indices. The set always contains `initial_index`. It may hold fewer than
/// `n` indices when every remaining point is at a zero, negative or NaN distance, e.g.
/// exact duplicates of already-selected points.
///
/// # Panics
/// Panics if `initial_index` is out of bounds for `points`.
#[must_use]
fn sample_with_distance_fn<T, const N: usize, F>(
    points: &[Point<T, N>],
    n: usize,
    initial_index: usize,
    distance_fn: F,
) -> HashSet<usize>
where
    T: FloatNumber,
    F: Fn(usize, &Point<T, N>, &Point<T, N>) -> T,
{
    let mut selected = HashSet::with_capacity(n);
    selected.insert(initial_index);
    // Distance from each point to its nearest selected point. Starting at infinity makes
    // the first refresh simply record the distance to the initial point.
    let mut distances = vec![T::infinity(); points.len()];
    // Indexed eagerly so an out-of-bounds `initial_index` panics regardless of `n`.
    let mut latest_point = &points[initial_index];
    while selected.len() < n {
        // Only the most recently selected point can lower a nearest-selected distance,
        // so refresh against it alone. The refresh is done lazily, right before it is
        // needed, so no pass is wasted after the final pick.
        update_distances(
            points,
            &mut distances,
            &selected,
            latest_point,
            &distance_fn,
        );
        match find_farthest_index(&distances, &selected) {
            Some(farthest_index) => {
                selected.insert(farthest_index);
                latest_point = &points[farthest_index];
            }
            // No remaining point is farther than zero from the selection (e.g. all
            // duplicates), so there is nothing meaningful left to pick.
            None => break,
        }
    }
    selected
}

/// Returns the index of the unselected point with the largest strictly positive distance.
///
/// Indices in `selected` are skipped, and ties resolve to the lowest index. Zero, negative
/// and NaN distances never qualify. `None` therefore means no remaining point lies farther
/// from the current selection.
#[inline]
#[must_use]
fn find_farthest_index<T>(distances: &[T], selected: &HashSet<usize>) -> Option<usize>
where
    T: FloatNumber,
{
    let mut farthest_index = None;
    // Seeding with zero means only strictly positive distances can become the farthest.
    let mut farthest_distance = T::zero();
    for (index, &distance) in distances.iter().enumerate() {
        if selected.contains(&index) {
            continue;
        }
        if distance > farthest_distance {
            farthest_index = Some(index);
            farthest_distance = distance;
        }
    }
    farthest_index
}
/// Refreshes every point's nearest-selected distance against `farthest_point`.
///
/// For each index of `points`, entries contained in `selected` are reset to zero. Every
/// other entry becomes the minimum of its current value and
/// `distance_fn(index, point, farthest_point)`.
///
/// # Panics
/// Panics if `distances` is shorter than `points`.
#[inline]
fn update_distances<T, const N: usize, F>(
    points: &[Point<T, N>],
    distances: &mut [T],
    selected: &HashSet<usize>,
    farthest_point: &Point<T, N>,
    distance_fn: &F,
) where
    T: FloatNumber,
    F: Fn(usize, &Point<T, N>, &Point<T, N>) -> T,
{
    for (position, candidate) in points.iter().enumerate() {
        distances[position] = if selected.contains(&position) {
            // A selected point is its own nearest selected point.
            T::zero()
        } else {
            let via_latest = distance_fn(position, candidate, farthest_point);
            distances[position].min(via_latest)
        };
    }
}
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    /// Nine distinct 2-D points; the trailing comment on each entry is its index.
    #[must_use]
    fn sample_points() -> Vec<Point<f32, 2>> {
        vec![
            [0.0, 0.0], // 0
            [0.1, 0.1], // 1
            [0.1, 0.2], // 2
            [0.2, 0.2], // 3
            [0.2, 0.4], // 4
            [0.3, 0.5], // 5
            [0.1, 0.0], // 6
            [0.0, 0.1], // 7
            [0.0, 0.2], // 8
        ]
    }

    #[must_use]
    fn empty_points() -> Vec<Point<f32, 2>> {
        vec![]
    }

    #[rstest]
    #[case(0, vec![])]
    #[case(1, vec![0])]
    #[case(3, vec![0, 3, 5])]
    #[case(5, vec![0, 1, 3, 5, 8])]
    #[case(9, vec![0, 1, 2, 3, 4, 5, 6, 7, 8])]
    #[case(10, vec![0, 1, 2, 3, 4, 5, 6, 7, 8])]
    fn test_sample_farthest_point_sampling(#[case] n: usize, #[case] expected: Vec<usize>) {
        let points = sample_points();
        let sampling = SamplingStrategy::FarthestPointSampling;
        let sampled = sampling.sample(&points, n);
        assert_eq!(sampled, expected.into_iter().collect());
    }

    #[test]
    fn test_sample_farthest_point_sampling_empty() {
        let points = empty_points();
        let sampling = SamplingStrategy::FarthestPointSampling;
        let sampled = sampling.sample(&points, 2);
        assert!(sampled.is_empty());
    }

    #[test]
    fn test_sample_farthest_point_sampling_duplicates() {
        // Every point coincides, so no candidate is ever farther than the initial point:
        // sampling stops after it even though two samples were requested.
        let points: Vec<Point<f32, 2>> = vec![[0.5, 0.5]; 4];
        let sampling = SamplingStrategy::FarthestPointSampling;
        let sampled = sampling.sample(&points, 2);
        assert_eq!(sampled, HashSet::from([0]));
    }

    #[rstest]
    #[case(0, vec![])]
    #[case(1, vec![8])]
    #[case(3, vec![5, 6, 8])]
    #[case(5, vec![3, 5, 6, 7, 8])]
    #[case(9, vec![0, 1, 2, 3, 4, 5, 6, 7, 8])]
    #[case(10, vec![0, 1, 2, 3, 4, 5, 6, 7, 8])]
    fn test_sample_weighted_farthest_point_sampling(
        #[case] n: usize,
        #[case] expected: Vec<usize>,
    ) {
        let weights = vec![1.0, 1.0, 2.0, 3.0, 5.0, 8.0, 13.0, 21.0, 34.0];
        let sampling = SamplingStrategy::WeightedFarthestPointSampling(weights);
        let points = sample_points();
        let actual = sampling.sample(&points, n);
        assert_eq!(actual, expected.into_iter().collect());
    }

    #[test]
    fn test_sample_weighted_farthest_point_sampling_empty() {
        let weights = vec![];
        let sampling = SamplingStrategy::WeightedFarthestPointSampling(weights);
        let points = empty_points();
        let actual = sampling.sample(&points, 2);
        assert!(actual.is_empty());
    }

    // The length check is a `debug_assert_eq!`, so it only fires in debug builds.
    #[cfg(debug_assertions)]
    #[test]
    #[should_panic]
    fn test_sample_weighted_farthest_point_sampling_invalid() {
        let weights = vec![1.0, 2.0];
        let sampling = SamplingStrategy::WeightedFarthestPointSampling(weights);
        let points = sample_points();
        let _ = sampling.sample(&points, 2);
    }
}