use crate::DeterministicScore;
use khive_types::DistanceMetric;
/// Converts a raw distance from a vector search into a [`DeterministicScore`]
/// where larger values mean "more similar".
///
/// How the distance is mapped depends on `metric`:
/// - `L2`: `1 / (1 + max(dist, 0))`. Negative distances are clamped to zero, so the
///   result lies in `(0, 1]` for finite input.
/// - `Dot`: `-dist`. The distance is negated, so smaller distances score higher.
/// - `Cosine`: `1 - dist`.
/// - Any metric not listed above uses the cosine mapping `1 - dist`.
///
/// A `NaN` distance is treated as `0.0` before mapping. All arithmetic is done in
/// `f64` after a lossless widening from `f32`.
#[inline]
pub fn score_from_distance(dist: f32, metric: DistanceMetric) -> DeterministicScore {
    // NaN would poison every branch below, so substitute "zero distance" up front.
    let distance: f64 = if dist.is_nan() { 0.0 } else { f64::from(dist) };

    let similarity = match metric {
        DistanceMetric::L2 => {
            let non_negative = distance.max(0.0);
            1.0 / (1.0 + non_negative)
        }
        DistanceMetric::Dot => -distance,
        DistanceMetric::Cosine => 1.0 - distance,
        _ => 1.0 - distance,
    };

    DeterministicScore::from_f64(similarity)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing a score's `f64` view to an expected value.
    const EPS: f64 = 1e-6;

    /// Asserts that `actual` is within [`EPS`] of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn cosine_basic() {
        assert_close(score_from_distance(0.2, DistanceMetric::Cosine).to_f64(), 0.8);
    }

    #[test]
    fn dot_basic() {
        assert_close(score_from_distance(-5.0, DistanceMetric::Dot).to_f64(), 5.0);
    }

    #[test]
    fn l2_basic() {
        assert_close(score_from_distance(1.0, DistanceMetric::L2).to_f64(), 0.5);
    }

    #[test]
    fn l2_zero_distance() {
        assert_close(score_from_distance(0.0, DistanceMetric::L2).to_f64(), 1.0);
    }

    #[test]
    fn l2_large_distance() {
        let s = score_from_distance(1_000_000.0_f32, DistanceMetric::L2);
        assert!(s.to_f64() < 1e-5, "got {}", s.to_f64());
        assert!(s.to_f64() >= 0.0, "similarity must be non-negative");
    }

    /// The L2 branch clamps negative distances to zero before mapping, so a
    /// negative distance must score the same as an exact match (1.0).
    #[test]
    fn l2_negative_distance_is_clamped() {
        assert_close(score_from_distance(-3.0, DistanceMetric::L2).to_f64(), 1.0);
    }

    #[test]
    fn cosine_zero_distance() {
        assert_close(score_from_distance(0.0, DistanceMetric::Cosine).to_f64(), 1.0);
    }

    #[test]
    fn cosine_max_distance() {
        assert_close(score_from_distance(2.0, DistanceMetric::Cosine).to_f64(), -1.0);
    }

    #[test]
    fn nan_maps_to_zero_distance() {
        let s = score_from_distance(f32::NAN, DistanceMetric::Cosine);
        assert!(
            (s.to_f64() - 1.0).abs() < EPS,
            "NaN should map to similarity 1.0, got {}",
            s.to_f64()
        );
    }

    /// NaN is replaced by 0.0 before the per-metric mapping, so every metric
    /// must produce its zero-distance score.
    #[test]
    fn nan_maps_to_zero_distance_for_other_metrics() {
        assert_close(score_from_distance(f32::NAN, DistanceMetric::Dot).to_f64(), 0.0);
        assert_close(score_from_distance(f32::NAN, DistanceMetric::L2).to_f64(), 1.0);
    }

    #[test]
    fn parity_with_hnsw_local_impl() {
        // Independent copy of the mapping; must stay bit-identical to `score_from_distance`.
        fn reference(dist: f32, metric: DistanceMetric) -> f64 {
            let d = if dist.is_nan() { 0.0 } else { dist } as f64;
            match metric {
                DistanceMetric::Cosine => 1.0 - d,
                DistanceMetric::Dot => -d,
                DistanceMetric::L2 => 1.0 / (1.0 + d.max(0.0)),
                _ => 1.0 - d,
            }
        }

        let cases: &[(f32, DistanceMetric)] = &[
            (0.0, DistanceMetric::Cosine),
            (0.2, DistanceMetric::Cosine),
            (1.0, DistanceMetric::Cosine),
            (2.0, DistanceMetric::Cosine),
            (f32::NAN, DistanceMetric::Cosine),
            (-5.0, DistanceMetric::Dot),
            (0.0, DistanceMetric::Dot),
            (3.0, DistanceMetric::Dot),
            (f32::NAN, DistanceMetric::Dot),
            (0.0, DistanceMetric::L2),
            (1.0, DistanceMetric::L2),
            (4.0, DistanceMetric::L2),
            (1_000_000.0, DistanceMetric::L2),
            (-3.0, DistanceMetric::L2),
            (f32::NAN, DistanceMetric::L2),
        ];

        for &(dist, metric) in cases {
            let expected = DeterministicScore::from_f64(reference(dist, metric));
            let got = score_from_distance(dist, metric);
            assert_eq!(
                got,
                expected,
                "parity failure for dist={dist:?} metric={metric:?}: \
                 expected raw={} got raw={}",
                expected.to_raw(),
                got.to_raw()
            );
        }
    }
}