use std::{cmp::Ordering, collections::HashSet};
use num_traits::clamp;
use crate::math::{DistanceMetric, FloatNumber, Point};
/// Strategy used to pick a representative subset of points.
#[derive(Debug, Default, PartialEq)]
pub enum SamplingStrategy<T>
where
T: FloatNumber,
{
/// Greedy farthest-point sampling using squared Euclidean distance,
/// seeded at index 0.
#[default]
Farthest,
/// Farthest-point sampling where each candidate's distance is scaled by a
/// per-point weight; the vector must hold one weight per point.
WeightedFarthest(Vec<T>),
/// Rank-based trade-off between per-point scores and dissimilarity to the
/// already-selected points; the first field is the dissimilarity weight
/// (clamped to [0, 1] at sampling time), the second holds one score per
/// point.
Diversity(T, Vec<T>),
}
impl<T> SamplingStrategy<T>
where
    T: FloatNumber,
{
    /// Selects up to `n` point indices from `points` according to this strategy.
    ///
    /// Returns an empty set when `n` is zero or `points` is empty, and every
    /// index when `n` is at least the number of points. In debug builds the
    /// weighted and diversity variants assert that their auxiliary vector has
    /// one entry per point.
    pub fn sample<const N: usize>(&self, points: &[Point<T, N>], n: usize) -> HashSet<usize> {
        let total = points.len();
        if n == 0 || total == 0 {
            return HashSet::new();
        }
        if total <= n {
            // Asking for everything (or more): no selection needed.
            return (0..total).collect();
        }
        match self {
            Self::Farthest => sample_farthest_point(points, n),
            Self::WeightedFarthest(weights) => {
                debug_assert_eq!(
                    total,
                    weights.len(),
                    "The number of points and weights must be equal."
                );
                sample_weighted_farthest_point(points, weights, n)
            }
            Self::Diversity(weight, scores) => {
                debug_assert_eq!(
                    total,
                    scores.len(),
                    "The number of points and scores must be equal."
                );
                // Keep the trade-off weight inside [0, 1].
                let bounded = clamp(*weight, T::zero(), T::one());
                sample_diversity(points, scores, bounded, n)
            }
        }
    }
}
/// Unweighted farthest-point sampling seeded at index 0.
///
/// Candidate-to-selected distances are measured with the squared Euclidean
/// metric.
#[must_use]
fn sample_farthest_point<T, const N: usize>(points: &[Point<T, N>], n: usize) -> HashSet<usize>
where
    T: FloatNumber,
{
    // The candidate index is irrelevant for an unweighted metric.
    let distance = |_: usize, a: &Point<T, N>, b: &Point<T, N>| {
        DistanceMetric::SquaredEuclidean.measure(a, b)
    };
    sample_with_distance_fn(points, n, 0, distance)
}
/// Weighted farthest-point sampling.
///
/// The selection is seeded with the highest-weighted point (index 0 if the
/// weight slice is empty), and every candidate's squared Euclidean distance is
/// scaled by that candidate's weight.
#[must_use]
fn sample_weighted_farthest_point<T, const N: usize>(
    points: &[Point<T, N>],
    weights: &[T],
    n: usize,
) -> HashSet<usize>
where
    T: FloatNumber,
{
    // Incomparable weights (e.g. NaN) are treated as equal, like elsewhere in
    // this module; `max_by` then keeps the later index on ties.
    let seed = weights
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal))
        .map_or(0, |(index, _)| index);
    sample_with_distance_fn(points, n, seed, |index, a, b| {
        weights[index] * DistanceMetric::SquaredEuclidean.measure(a, b)
    })
}
/// Greedy max-min sampling driven by a caller-supplied distance function.
///
/// Starting from `initial_index`, repeatedly selects the point whose minimum
/// distance to the already-selected set is largest, until `n` indices are
/// chosen. `distance_fn` receives the candidate's index along with the
/// candidate point and the most recently selected point.
#[must_use]
fn sample_with_distance_fn<T, const N: usize, F>(
    points: &[Point<T, N>],
    n: usize,
    initial_index: usize,
    distance_fn: F,
) -> HashSet<usize>
where
    T: FloatNumber,
    F: Fn(usize, &Point<T, N>, &Point<T, N>) -> T,
{
    let mut chosen = HashSet::with_capacity(n);
    chosen.insert(initial_index);
    // distances[i] tracks the smallest distance from point i to any chosen
    // point; selected entries are pinned at zero by `update_distances`.
    let mut distances = vec![T::infinity(); points.len()];
    let mut latest = initial_index;
    loop {
        update_distances(points, &mut distances, &chosen, &points[latest], &distance_fn);
        if chosen.len() >= n {
            break;
        }
        let candidate = find_farthest_index(&distances, &chosen);
        // A repeated index means no further progress is possible.
        if !chosen.insert(candidate) {
            break;
        }
        latest = candidate;
    }
    chosen
}
/// Index of the largest entry in `distances` among indices not in `selected`.
///
/// Mirrors `Iterator::max_by`: on ties (including incomparable values, which
/// compare as equal here) the later index wins. Falls back to 0 when every
/// index is already selected.
#[inline]
#[must_use]
fn find_farthest_index<T>(distances: &[T], selected: &HashSet<usize>) -> usize
where
    T: FloatNumber,
{
    let mut best: Option<(usize, &T)> = None;
    for (index, distance) in distances.iter().enumerate() {
        if selected.contains(&index) {
            continue;
        }
        // Keep the current best only when it is strictly greater; otherwise
        // the later candidate replaces it, matching `max_by` tie behavior.
        let keep_current = best.map_or(false, |(_, current)| {
            current.partial_cmp(distance).unwrap_or(Ordering::Equal) == Ordering::Greater
        });
        if !keep_current {
            best = Some((index, distance));
        }
    }
    best.map_or(0, |(index, _)| index)
}
/// Folds the distance to `farthest_point` into each point's running minimum.
///
/// Entries for already-selected points are pinned at zero so they never win
/// a farthest-point search again.
#[inline]
fn update_distances<T, const N: usize, F>(
    points: &[Point<T, N>],
    distances: &mut [T],
    selected: &HashSet<usize>,
    farthest_point: &Point<T, N>,
    distance_fn: &F,
) where
    T: FloatNumber,
    F: Fn(usize, &Point<T, N>, &Point<T, N>) -> T,
{
    for (index, (point, slot)) in points.iter().zip(distances.iter_mut()).enumerate() {
        *slot = if selected.contains(&index) {
            T::zero()
        } else {
            (*slot).min(distance_fn(index, point, farthest_point))
        };
    }
}
/// Rank-based diversity sampling balancing scores against dissimilarity.
///
/// Selection is seeded with the highest-scoring point. Each round, every
/// unselected point's smallest squared Euclidean distance to the selected set
/// is refreshed against the most recent pick, both scores and those distances
/// are ranked descending, and `find_best_index` chooses the next index.
/// `weight` (expected in [0, 1]) steers the trade-off: 0 favors high scores,
/// 1 favors points far from previous picks.
#[must_use]
fn sample_diversity<T, const N: usize>(
    points: &[Point<T, N>],
    scores: &[T],
    weight: T,
    n: usize,
) -> HashSet<usize>
where
    T: FloatNumber,
{
    let score_rankings = sort_scores_descending(scores);
    // Seed with the best-scoring point.
    let mut newest = score_rankings.indices[0];
    let mut selected = HashSet::with_capacity(n);
    selected.insert(newest);
    // dissimilarities[i] holds the smallest squared distance from point i to
    // any selected point; selected entries are pinned at zero.
    let mut dissimilarities = vec![T::max_value(); points.len()];
    while selected.len() < n {
        let reference = &points[newest];
        for (index, point) in points.iter().enumerate() {
            dissimilarities[index] = if selected.contains(&index) {
                T::zero()
            } else {
                dissimilarities[index]
                    .min(DistanceMetric::SquaredEuclidean.measure(point, reference))
            };
        }
        let dissimilarity_rankings = sort_scores_descending(&dissimilarities);
        if let Some(next) = find_best_index(
            &score_rankings.rankings,
            &dissimilarity_rankings.rankings,
            &selected,
            weight,
        ) {
            selected.insert(next);
            newest = next;
        } else {
            // No eligible candidate remains.
            break;
        }
    }
    selected
}
/// Ranking produced by [`sort_scores_descending`].
struct RankedScores {
    /// `rankings[i]` is the descending-order rank of the i-th score (0 = highest).
    rankings: Vec<usize>,
    /// Original indices ordered from highest score to lowest.
    indices: Vec<usize>,
}

/// Ranks `scores` from highest to lowest.
///
/// Ties — and incomparable values such as NaN, which compare as equal here —
/// keep their original relative order because the sort is stable.
#[must_use]
fn sort_scores_descending<T>(scores: &[T]) -> RankedScores
where
    T: PartialOrd,
{
    let mut indices: Vec<usize> = (0..scores.len()).collect();
    // Comparing right-to-left yields descending order.
    indices.sort_by(|&left, &right| {
        scores[right]
            .partial_cmp(&scores[left])
            .unwrap_or(Ordering::Equal)
    });
    // Invert the permutation: position in `indices` is the rank.
    let mut rankings = vec![0; scores.len()];
    for (rank, &index) in indices.iter().enumerate() {
        rankings[index] = rank;
    }
    RankedScores { rankings, indices }
}
/// Picks the unselected index minimizing the weighted combination of score
/// rank and dissimilarity rank.
///
/// Ranks are shifted to be 1-based so a rank of 0 cannot erase one side of the
/// weighting. The first index achieving the strictly smallest combined value
/// wins; returns `None` when every index is selected (or nothing beats the
/// `T::max_value()` sentinel).
#[must_use]
fn find_best_index<T>(
    score_rankings: &[usize],
    dissimilarity_rankings: &[usize],
    selected: &HashSet<usize>,
    weight: T,
) -> Option<usize>
where
    T: FloatNumber,
{
    const RANK_OFFSET: usize = 1;
    let mut winner = None;
    let mut lowest = T::max_value();
    for (index, &score_rank) in score_rankings.iter().enumerate() {
        if selected.contains(&index) {
            continue;
        }
        let dissimilarity_rank = dissimilarity_rankings[index];
        let combined = T::from_usize(score_rank + RANK_OFFSET) * (T::one() - weight)
            + T::from_usize(dissimilarity_rank + RANK_OFFSET) * weight;
        if combined < lowest {
            winner = Some(index);
            lowest = combined;
        }
    }
    winner
}
#[cfg(test)]
#[cfg_attr(coverage, coverage(off))]
mod tests {
use rstest::rstest;
use super::*;
// Fixture: nine 2-D points shared by all sampling tests below.
#[must_use]
fn sample_points() -> Vec<Point<f32, 2>> {
vec![
[0.0, 0.0], [0.1, 0.1], [0.1, 0.2], [0.2, 0.2], [0.2, 0.4], [0.3, 0.5], [0.1, 0.0], [0.0, 0.1], [0.0, 0.2], ]
}
// Fixture: an empty point set for the degenerate cases.
#[must_use]
fn empty_points() -> Vec<Point<f32, 2>> {
vec![]
}
// Farthest-point sampling: checks the exact indices chosen for several n,
// including n == 0 and n larger than the point count.
#[rstest]
#[case(0, vec ! [])]
#[case(1, vec ! [0])]
#[case(3, vec ! [0, 3, 5])]
#[case(5, vec ! [0, 1, 3, 5, 8])]
#[case(9, vec ! [0, 1, 2, 3, 4, 5, 6, 7, 8])]
#[case(10, vec ! [0, 1, 2, 3, 4, 5, 6, 7, 8])]
fn test_sample_farthest_point_sampling(#[case] n: usize, #[case] expected: Vec<usize>) {
let points = sample_points();
let sampled = SamplingStrategy::Farthest.sample(&points, n);
assert_eq!(sampled, expected.into_iter().collect());
}
// Sampling from an empty slice must return an empty set.
#[test]
fn test_sample_farthest_point_sampling_empty() {
let points = empty_points();
let sampled = SamplingStrategy::Farthest.sample(&points, 2);
assert!(
sampled.is_empty(),
"Sampling from empty points should return empty set"
);
}
// Weighted farthest-point sampling: the seed is the highest-weighted point
// (index 8 here), so the chosen indices differ from the unweighted variant.
#[rstest]
#[case(0, vec ! [])]
#[case(1, vec ! [8])]
#[case(3, vec ! [5, 6, 8])]
#[case(5, vec ! [3, 5, 6, 7, 8])]
#[case(9, vec ! [0, 1, 2, 3, 4, 5, 6, 7, 8])]
#[case(10, vec ! [0, 1, 2, 3, 4, 5, 6, 7, 8])]
fn test_sample_weighted_farthest_point_sampling(
#[case] n: usize,
#[case] expected: Vec<usize>,
) {
let weights = vec![1.0, 1.0, 2.0, 3.0, 5.0, 8.0, 13.0, 21.0, 34.0];
let sampling = SamplingStrategy::WeightedFarthest(weights);
let points = sample_points();
let actual = sampling.sample(&points, n);
assert_eq!(actual, expected.into_iter().collect());
}
// Weighted sampling over no points yields an empty set.
#[test]
fn test_sample_weighted_farthest_point_sampling_empty() {
let weights = vec![];
let sampling = SamplingStrategy::WeightedFarthest(weights);
let points = empty_points();
let actual = sampling.sample(&points, 2);
assert!(
actual.is_empty(),
"Sampling from empty points should return empty set"
);
}
// The points/weights length mismatch is a debug_assert, so this panic test
// only runs in builds with debug assertions enabled.
#[cfg(debug_assertions)]
#[test]
#[should_panic(expected = "The number of points and weights must be equal.")]
fn test_sample_weighted_farthest_point_sampling_invalid() {
let weights = vec![1.0, 2.0];
let sampling = SamplingStrategy::WeightedFarthest(weights);
let points = sample_points();
let _ = sampling.sample(&points, 2);
}
// Diversity sampling with an even score/dissimilarity trade-off (weight 0.5).
#[rstest]
#[case(0, vec ! [])]
#[case(1, vec ! [8])]
#[case(3, vec ! [5, 6, 8])]
#[case(5, vec ! [3, 4, 5, 6, 8])]
#[case(9, vec ! [0, 1, 2, 3, 4, 5, 6, 7, 8])]
#[case(10, vec ! [0, 1, 2, 3, 4, 5, 6, 7, 8])]
fn test_sample_diversity_sampling(#[case] n: usize, #[case] expected: Vec<usize>) {
let scores = vec![1.0, 1.0, 2.0, 3.0, 5.0, 8.0, 13.0, 21.0, 34.0];
let weight = 0.5;
let sampling = SamplingStrategy::Diversity(weight, scores);
let points = sample_points();
let actual = sampling.sample(&points, n);
assert_eq!(actual, expected.into_iter().collect());
}
// Diversity sampling over no points yields an empty set.
#[test]
fn test_sample_diversity_sampling_empty() {
let points = empty_points();
let scores: Vec<f32> = vec![];
let weight = 0.5;
let sampling = SamplingStrategy::Diversity(weight, scores);
let actual = sampling.sample(&points, 3);
assert!(actual.is_empty());
}
// Mismatched points/scores lengths must trip the debug assertion.
#[cfg(debug_assertions)]
#[test]
#[should_panic]
fn test_sample_diversity_sampling_invalid() {
let weight = 0.5;
let scores = vec![1.0, 2.0, 3.0];
let sampling = SamplingStrategy::Diversity(weight, scores);
let points = sample_points();
let _ = sampling.sample(&points, 2);
}
// sort_scores_descending: `indices` lists positions from highest to lowest
// score, and `rankings` maps each original index to its descending rank.
#[test]
fn test_sort_scores_descending() {
let scores = vec![3.0, 1.0, 2.0, 4.0, 5.0];
let ranked = sort_scores_descending(&scores);
assert_eq!(ranked.rankings, vec![2, 4, 3, 1, 0]);
assert_eq!(ranked.indices, vec![4, 3, 0, 2, 1]);
}
// find_best_index: with equal weighting and indices 0/1 selected, index 2
// balances score rank and dissimilarity rank best.
#[test]
fn test_find_best_index() {
let score_rankings = vec![0, 1, 2, 3, 4];
let dissimilarity_rankings = vec![4, 3, 2, 1, 0];
let selected = HashSet::from([0, 1]);
let weight = 0.5;
let best_index =
find_best_index(&score_rankings, &dissimilarity_rankings, &selected, weight);
assert_eq!(best_index, Some(2));
}
}