use crate::distance::DistanceMetric;
/// Returns `true` when results scored with `metric` should be listed
/// largest-value-first.
///
/// This is a direct pass-through of `DistanceMetric::higher_is_better`:
/// a metric where bigger values mean "more similar" is ranked descending, and
/// every other metric is ranked ascending.
fn should_sort_descending(metric: DistanceMetric) -> bool {
    // Descending order is exactly the "higher is better" case.
    metric.higher_is_better()
}
/// Returns `true` when `score` is strictly "more similar" than `threshold`
/// under `metric`.
///
/// For metrics where higher is better, this means `score > threshold`.
/// For the others, where a smaller value means closer, it means
/// `score < threshold`. Both comparisons are strict, so a score equal to the
/// threshold is rejected. A NaN `score` or `threshold` is always rejected,
/// because every ordered float comparison involving NaN is `false`.
fn filter_by_similarity_gt(metric: DistanceMetric, score: f32, threshold: f32) -> bool {
    if !metric.higher_is_better() {
        // Distance-style metric: being "more similar" means lying below the
        // threshold.
        return threshold > score;
    }
    score > threshold
}
/// Sorts `scores` in place so the most similar results come first under
/// `metric`.
///
/// - Metrics where a higher value is better (per
///   `DistanceMetric::higher_is_better`) are sorted descending.
/// - All other metrics are sorted ascending.
/// - The sort is stable, so equal scores keep their relative order.
///
/// NaN scores do not cause a panic. They are always placed after every non-NaN
/// score, whichever direction is used, so an undefined score can never outrank
/// a real one.
fn sort_by_similarity(metric: DistanceMetric, scores: &mut [f32]) {
    use std::cmp::Ordering;

    let descending = metric.higher_is_better();
    scores.sort_by(|a, b| match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        // NaN sorts after any real score in both directions.
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        (false, false) => {
            let ascending = a
                .partial_cmp(b)
                .expect("non-NaN f32 values are always comparable");
            if descending {
                ascending.reverse()
            } else {
                ascending
            }
        }
    });
}
#[cfg(test)]
mod tests {
    use super::*;

    // Pins which metrics report "higher is better"; the helpers above all
    // branch on this flag.
    #[test]
    fn test_higher_is_better_semantics() {
        assert!(DistanceMetric::Cosine.higher_is_better());
        assert!(DistanceMetric::DotProduct.higher_is_better());
        assert!(DistanceMetric::Jaccard.higher_is_better());
        assert!(!DistanceMetric::Euclidean.higher_is_better());
        assert!(!DistanceMetric::Hamming.higher_is_better());
    }

    // Every metric variant's ranking direction is checked, not just one of each
    // family.
    #[test]
    fn test_sort_direction_for_metrics() {
        assert!(
            should_sort_descending(DistanceMetric::Cosine),
            "Cosine should sort DESC"
        );
        assert!(
            !should_sort_descending(DistanceMetric::Euclidean),
            "Euclidean should sort ASC"
        );
        assert!(
            should_sort_descending(DistanceMetric::DotProduct),
            "DotProduct should sort DESC"
        );
        assert!(
            should_sort_descending(DistanceMetric::Jaccard),
            "Jaccard should sort DESC"
        );
        assert!(
            !should_sort_descending(DistanceMetric::Hamming),
            "Hamming should sort ASC"
        );
    }

    #[test]
    fn test_threshold_comparison_semantics() {
        let high_cosine_score = 0.95;
        let low_cosine_score = 0.3;
        let low_euclidean_dist = 0.2;
        let high_euclidean_dist = 5.0;
        let threshold = 0.5;
        assert!(
            filter_by_similarity_gt(DistanceMetric::Cosine, high_cosine_score, threshold),
            "High cosine score should pass > threshold"
        );
        assert!(
            !filter_by_similarity_gt(DistanceMetric::Cosine, low_cosine_score, threshold),
            "Low cosine score should fail > threshold"
        );
        assert!(
            filter_by_similarity_gt(DistanceMetric::Euclidean, low_euclidean_dist, threshold),
            "Low euclidean distance (0.2) should pass similarity > 0.5"
        );
        assert!(
            !filter_by_similarity_gt(DistanceMetric::Euclidean, high_euclidean_dist, threshold),
            "High euclidean distance (5.0) should fail similarity > 0.5"
        );
    }

    // The filter is a strict comparison in both directions, so a value exactly
    // at the threshold must be rejected.
    #[test]
    fn test_threshold_comparison_is_strict() {
        let threshold = 0.5;
        assert!(
            !filter_by_similarity_gt(DistanceMetric::Cosine, threshold, threshold),
            "Cosine score equal to the threshold should fail a strict > comparison"
        );
        assert!(
            !filter_by_similarity_gt(DistanceMetric::Euclidean, threshold, threshold),
            "Euclidean distance equal to the threshold should fail a strict comparison"
        );
    }

    #[test]
    fn test_sort_results_by_similarity() {
        let mut cosine_scores = vec![0.3, 0.9, 0.5, 0.7];
        let mut euclidean_dists = vec![0.3, 0.9, 0.5, 0.7];
        sort_by_similarity(DistanceMetric::Cosine, &mut cosine_scores);
        sort_by_similarity(DistanceMetric::Euclidean, &mut euclidean_dists);
        assert_eq!(cosine_scores, vec![0.9, 0.7, 0.5, 0.3]);
        assert_eq!(euclidean_dists, vec![0.3, 0.5, 0.7, 0.9]);
    }
}