use crate::distance::{DistanceMetric, distance};
/// Computes the distance from `query` to every vector in `candidates` under `metric`.
///
/// The result holds exactly one entry per candidate, in the same order as `candidates`.
/// An empty `candidates` slice produces an empty vector.
pub fn batch_distances(query: &[f32], candidates: &[&[f32]], metric: DistanceMetric) -> Vec<f32> {
    let mut out = Vec::with_capacity(candidates.len());
    for &cand in candidates {
        out.push(distance(query, cand, metric));
    }
    out
}
/// Decides whether `candidate_vec` may be kept given the already-selected vectors.
///
/// `candidate_dist_to_query` is the caller-supplied distance from `candidate_vec` to the
/// query. For each vector in `selected_vecs`, `distance(candidate_vec, selected, metric)`
/// is computed. The candidate is rejected (returns `false`) as soon as that distance is
/// strictly smaller than `candidate_dist_to_query`. Ties are accepted. An empty
/// `selected_vecs` slice yields `true`.
pub fn is_diverse_batched(
    candidate_vec: &[f32],
    candidate_dist_to_query: f32,
    selected_vecs: &[&[f32]],
    metric: DistanceMetric,
) -> bool {
    // Keep the strict `>` test (rather than a negated `<=`). A comparison involving NaN
    // then never triggers rejection. `any` stops at the first rejecting vector.
    let closer_selected_exists = selected_vecs
        .iter()
        .any(|&other| candidate_dist_to_query > distance(candidate_vec, other, metric));
    !closer_selected_exists
}
#[cfg(test)]
mod tests {
    use super::*;

    // One distance per candidate, in input order; identical vectors are at distance 0.
    #[test]
    fn batch_distances_correctness() {
        let query = [1.0, 0.0, 0.0];
        let c1 = [0.0, 1.0, 0.0];
        let c2 = [1.0, 0.0, 0.0];
        let c3 = [0.0, 0.0, 1.0];
        let dists = batch_distances(&query, &[&c1, &c2, &c3], DistanceMetric::L2);
        assert_eq!(dists.len(), 3);
        assert_eq!(dists[1], 0.0);
        assert_eq!(dists[0], dists[2]);
    }

    // No candidates means no distances.
    #[test]
    fn batch_distances_empty_candidates() {
        let query: [f32; 3] = [1.0, 0.0, 0.0];
        let dists = batch_distances(&query, &[], DistanceMetric::L2);
        assert!(dists.is_empty());
    }

    // A selected vector that is closer than the query rejects the candidate;
    // a sufficiently small query distance keeps it.
    #[test]
    fn diversity_check() {
        let candidate = [1.0, 0.0];
        let selected1 = [0.9, 0.1];
        assert!(!is_diverse_batched(
            &candidate,
            0.5,
            &[&selected1],
            DistanceMetric::L2,
        ));
        assert!(is_diverse_batched(
            &candidate,
            0.01,
            &[&selected1],
            DistanceMetric::L2,
        ));
    }

    // With nothing selected yet, every candidate is diverse.
    #[test]
    fn diversity_empty_selected_is_diverse() {
        let candidate: [f32; 2] = [1.0, 0.0];
        assert!(is_diverse_batched(&candidate, 0.5, &[], DistanceMetric::L2));
    }

    // Rejection uses a strict `>`, so a tie between the two distances is accepted.
    #[test]
    fn diversity_accepts_ties() {
        let candidate: [f32; 2] = [1.0, 0.0];
        let selected: [f32; 2] = [0.9, 0.1];
        let d: f32 = distance(&candidate, &selected, DistanceMetric::L2);
        assert!(is_diverse_batched(
            &candidate,
            d,
            &[&selected],
            DistanceMetric::L2,
        ));
    }

    // Any closer selected vector rejects the candidate, not just the first one checked.
    #[test]
    fn diversity_rejects_on_later_selected() {
        let candidate: [f32; 2] = [1.0, 0.0];
        let far: [f32; 2] = [-1.0, 0.0];
        let near: [f32; 2] = [0.9, 0.1];
        assert!(!is_diverse_batched(
            &candidate,
            0.5,
            &[&far, &near],
            DistanceMetric::L2,
        ));
    }
}