use std::{cmp::Ordering, collections::HashSet};
use crate::{
math::{
sampling::{algorithm::SamplingAlgorithm, error::SamplingError},
DistanceMetric,
Point,
},
FloatNumber,
};
/// Weighted diversity sampling: greedily selects points that balance a
/// precomputed weight ranking against dissimilarity to the points already
/// selected (maximal-marginal-relevance style).
#[derive(Debug, PartialEq)]
pub struct DiversitySampling<T>
where
    T: FloatNumber,
{
    // Trade-off in [0, 1]: 0 = rank purely by weight, 1 = purely by diversity.
    diversity_factor: T,
    // Descending ranking of the per-point weights, built once at construction.
    ranked: RankedScores<T>,
    // Distance metric used to measure dissimilarity between points.
    metric: DistanceMetric,
}
impl<T> DiversitySampling<T>
where
    T: FloatNumber,
{
    /// Creates a new `DiversitySampling`.
    ///
    /// # Arguments
    /// * `diversity_factor` - Trade-off in `[0, 1]` between weight rank (0)
    ///   and diversity (1).
    /// * `weights` - One score per candidate point; higher means more preferred.
    /// * `metric` - Distance metric used to measure dissimilarity.
    ///
    /// # Errors
    /// * `SamplingError::InvalidDiversity` if `diversity_factor` is outside
    ///   `[0, 1]` or is NaN.
    /// * `SamplingError::EmptyWeights` if `weights` is empty.
    pub fn new(
        diversity_factor: T,
        weights: Vec<T>,
        metric: DistanceMetric,
    ) -> Result<Self, SamplingError> {
        // Written as a negated positive check so a NaN factor is rejected:
        // with the previous `< 0 || > 1` form both comparisons are false for
        // NaN and an invalid factor slipped through validation.
        if !(diversity_factor >= T::zero() && diversity_factor <= T::one()) {
            return Err(SamplingError::InvalidDiversity);
        }
        if weights.is_empty() {
            return Err(SamplingError::EmptyWeights);
        }
        Ok(Self {
            diversity_factor,
            ranked: RankedScores::new(weights),
            metric,
        })
    }
}
impl<T> SamplingAlgorithm<T> for DiversitySampling<T>
where
    T: FloatNumber,
{
    /// Selects up to `num_samples` indices from `points`, greedily balancing
    /// the stored weight ranking against distance to the already-selected
    /// points (controlled by `diversity_factor`).
    ///
    /// # Errors
    /// * `SamplingError::EmptyPoints` if `points` is empty.
    /// * `SamplingError::WeightsLengthMismatch` if `points` and the stored
    ///   weights differ in length.
    fn sample<const N: usize>(
        &self,
        points: &[Point<T, N>],
        num_samples: usize,
    ) -> Result<HashSet<usize>, SamplingError> {
        if points.is_empty() {
            return Err(SamplingError::EmptyPoints);
        }
        if self.ranked.len() != points.len() {
            return Err(SamplingError::WeightsLengthMismatch {
                points_len: points.len(),
                weights_len: self.ranked.len(),
            });
        }
        if num_samples == 0 {
            return Ok(HashSet::new());
        }
        // Degenerate case: asked for at least as many samples as exist,
        // so every index is returned.
        if points.len() <= num_samples {
            return Ok((0..points.len()).collect());
        }
        // Seed the selection with the highest-weighted point
        // (indices[0] is the top of the descending ranking).
        let best_index = self.ranked.indices[0];
        let mut best_point = points[best_index];
        let mut selected = HashSet::with_capacity(num_samples);
        selected.insert(best_index);
        // For each candidate, track the minimum distance to any selected
        // point so far. Despite the name, `metric.measure` yields a
        // distance, so a LARGER value means MORE dissimilar. Starting at
        // max_value() makes the first `min` take the measured distance.
        let mut similarities = vec![T::max_value(); points.len()];
        while selected.len() < num_samples {
            for (index, point) in points.iter().enumerate() {
                if selected.contains(&index) {
                    // Selected points are forced to zero so they sink to the
                    // bottom of the descending dissimilarity ranking below.
                    similarities[index] = T::zero();
                } else {
                    // Only measure against the most recently added point:
                    // the running `min` already covers earlier selections.
                    let similarity = self.metric.measure(&best_point, point);
                    similarities[index] = similarities[index].min(similarity);
                }
            }
            // Re-rank candidates by min-distance each round
            // (rank 0 = farthest from every selected point).
            let dissimilarity_rankings = RankedScores::new(similarities.clone());
            // Shadows the outer `best_index`: the index minimizing the
            // weighted blend of weight rank and dissimilarity rank.
            let best_index = find_best_index(
                &self.ranked.rankings,
                &dissimilarity_rankings.rankings,
                &selected,
                self.diversity_factor,
            );
            match best_index {
                Some(index) => {
                    selected.insert(index);
                    best_point = points[index];
                }
                // No unselected candidate remains; stop early.
                None => break,
            }
        }
        Ok(selected)
    }
}
/// Scores paired with their descending ranking.
#[derive(Debug, PartialEq)]
struct RankedScores<T>
where
    T: FloatNumber,
{
    // The raw scores, in their original order.
    scores: Vec<T>,
    // rankings[i] = rank of scores[i] (0 = highest score).
    rankings: Vec<usize>,
    // indices[rank] = original index of the score holding that rank
    // (the inverse permutation of `rankings`).
    indices: Vec<usize>,
}
impl<T> RankedScores<T>
where
    T: FloatNumber,
{
    /// Ranks `scores` in descending order.
    ///
    /// Incomparable values (NaN) are treated as equal; the stable sort then
    /// keeps tied scores in their original relative order.
    #[must_use]
    fn new(scores: Vec<T>) -> Self {
        // Sort the index permutation so the highest score comes first.
        let mut indices: Vec<usize> = (0..scores.len()).collect();
        indices.sort_by(|&left, &right| {
            scores[right]
                .partial_cmp(&scores[left])
                .unwrap_or(Ordering::Equal)
        });
        // Invert the permutation: rankings[original_index] = its rank.
        let mut rankings = vec![0; scores.len()];
        for (rank, &original_index) in indices.iter().enumerate() {
            rankings[original_index] = rank;
        }
        RankedScores {
            scores,
            rankings,
            indices,
        }
    }

    /// Number of ranked scores.
    pub fn len(&self) -> usize {
        self.scores.len()
    }
}
/// Returns the unselected index with the lowest weighted blend of its score
/// rank and its dissimilarity rank, or `None` if every index is selected.
#[inline]
#[must_use]
fn find_best_index<T>(
    score_rankings: &[usize],
    dissimilarity_rankings: &[usize],
    selected: &HashSet<usize>,
    weight: T,
) -> Option<usize>
where
    T: FloatNumber,
{
    // Ranks are shifted to 1-based so that rank 0 still contributes to the
    // weighted sum.
    const RANK_OFFSET: usize = 1;
    let mut winner = None;
    let mut lowest = T::max_value();
    for (index, &score_rank) in score_rankings.iter().enumerate() {
        if selected.contains(&index) {
            continue;
        }
        let dissimilarity_rank = dissimilarity_rankings[index];
        // Blend the two ranks; `weight` shifts preference toward diversity.
        let combined = T::from_usize(score_rank + RANK_OFFSET) * (T::one() - weight)
            + T::from_usize(dissimilarity_rank + RANK_OFFSET) * weight;
        // Strict `<` keeps the earliest index on ties, matching the
        // original fold's tie-breaking.
        if combined < lowest {
            winner = Some(index);
            lowest = combined;
        }
    }
    winner
}
#[cfg(test)]
mod tests {
use rstest::rstest;
use super::*;
#[must_use]
fn sample_points() -> Vec<Point<f32, 2>> {
vec![
[0.0, 0.0], [0.1, 0.1], [0.1, 0.2], [0.2, 0.2], [0.2, 0.4], [0.3, 0.5], [0.1, 0.0], [0.0, 0.1], [0.0, 0.2], ]
}
#[must_use]
fn empty_points() -> Vec<Point<f32, 2>> {
vec![]
}
#[test]
fn test_new() {
let weights = vec![1.0, 2.0, 3.0];
let actual =
DiversitySampling::new(0.5, weights.clone(), DistanceMetric::SquaredEuclidean).unwrap();
assert_eq!(
actual,
DiversitySampling {
diversity_factor: 0.5,
ranked: RankedScores::new(weights),
metric: DistanceMetric::SquaredEuclidean,
}
);
}
#[rstest]
#[case::diversity_lt_0(-0.1)]
#[case::diversity_gt_1(1.1)]
fn test_new_invalid_diversity(#[case] diversity: f32) {
let weights = vec![1.0, 2.0, 3.0];
let actual = DiversitySampling::new(diversity, weights, DistanceMetric::SquaredEuclidean);
assert!(actual.is_err());
assert_eq!(actual.unwrap_err(), SamplingError::InvalidDiversity,);
}
#[test]
fn test_new_empty_weights() {
let weights = vec![];
let actual = DiversitySampling::new(0.5, weights, DistanceMetric::SquaredEuclidean);
assert!(actual.is_err());
assert_eq!(actual.unwrap_err(), SamplingError::EmptyWeights);
}
#[rstest]
#[case(0, vec ! [])]
#[case(1, vec ! [8])]
#[case(3, vec ! [5, 6, 8])]
#[case(5, vec ! [3, 4, 5, 6, 8])]
#[case(9, vec ! [0, 1, 2, 3, 4, 5, 6, 7, 8])]
#[case(10, vec ! [0, 1, 2, 3, 4, 5, 6, 7, 8])]
fn test_sample(#[case] num_samples: usize, #[case] expected: Vec<usize>) {
let weights = vec![1.0, 1.0, 2.0, 3.0, 5.0, 8.0, 13.0, 21.0, 34.0];
let algorithm =
DiversitySampling::new(0.8, weights, DistanceMetric::SquaredEuclidean).unwrap();
let points = sample_points();
let actual = algorithm.sample(&points, num_samples).unwrap();
assert_eq!(actual.len(), expected.len());
assert_eq!(actual, expected.into_iter().collect());
}
#[test]
fn test_sample_empty_points() {
let weights = vec![1.0, 1.0, 2.0, 3.0, 5.0, 8.0, 13.0, 21.0, 34.0];
let algorithm =
DiversitySampling::new(0.8, weights, DistanceMetric::SquaredEuclidean).unwrap();
let points = empty_points();
let actual = algorithm.sample(&points, 3);
assert!(actual.is_err());
assert_eq!(actual.unwrap_err(), SamplingError::EmptyPoints);
}
#[test]
fn test_sample_scores_length_mismatch() {
let weights = vec![1.0, 2.0, 3.0];
let algorithm =
DiversitySampling::new(0.8, weights.clone(), DistanceMetric::SquaredEuclidean).unwrap();
let points = sample_points();
let actual = algorithm.sample(&points, 3);
assert!(actual.is_err());
assert_eq!(
actual.unwrap_err(),
SamplingError::WeightsLengthMismatch {
points_len: points.len(),
weights_len: weights.len(),
}
);
}
}