use nodedb_types::vector_distance::DistanceMetric;
use crate::distance::scalar::scalar_distance;
/// Maps a distance `d` computed under `metric` to the score used by [`maxsim`].
///
/// Cosine distances become `1.0 - d`; a distance under any other metric is
/// simply negated.
fn dist_to_sim(d: f32, metric: DistanceMetric) -> f32 {
    if matches!(metric, DistanceMetric::Cosine) {
        1.0 - d
    } else {
        -d
    }
}
/// Computes the MaxSim score of a multi-vector `query` against a multi-vector `doc`.
///
/// For each query vector, the distance to every document vector is computed
/// with `scalar_distance` under `metric`, converted to a score with
/// `dist_to_sim`, and the largest score is kept. The per-query maxima are then
/// summed.
///
/// Returns `0.0` when either `query` or `doc` is empty.
pub fn maxsim(query: &[Vec<f32>], doc: &[Vec<f32>], metric: DistanceMetric) -> f32 {
    let nothing_to_score = query.is_empty() || doc.is_empty();
    if nothing_to_score {
        return 0.0;
    }

    // Best score any document vector achieves for a single query vector.
    let best_match = |q_vec: &Vec<f32>| -> f32 {
        let mut best = f32::NEG_INFINITY;
        for d_vec in doc {
            let sim = dist_to_sim(scalar_distance(q_vec, d_vec, metric), metric);
            best = best.max(sim);
        }
        best
    };

    query.iter().map(best_match).sum()
}
/// Computes [`maxsim`] using only a leading prefix of the query vectors.
///
/// The prefix length is `budget` capped at `query.len()`, so a budget larger
/// than the query scores the whole query. The selected prefix, `doc`, and
/// `metric` are passed to [`maxsim`] unchanged.
pub fn budgeted_maxsim(
    query: &[Vec<f32>],
    doc: &[Vec<f32>],
    budget: u8,
    metric: DistanceMetric,
) -> f32 {
    // `limit` never exceeds `query.len()`, so the split cannot go out of bounds.
    let limit = query.len().min(usize::from(budget));
    let (head, _rest) = query.split_at(limit);
    maxsim(head, doc, metric)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `dim`-length vector of zeros with a single `1.0` at index `pos`.
    fn unit_vec(dim: usize, pos: usize) -> Vec<f32> {
        let mut basis = vec![0.0_f32; dim];
        basis[pos] = 1.0;
        basis
    }

    /// Returns the unit vectors for axes `0..count` of a 4-dimensional space.
    fn basis4(count: usize) -> Vec<Vec<f32>> {
        (0..count).map(|axis| unit_vec(4, axis)).collect()
    }

    /// Scoring a token set against itself under L2 is expected to give 0.
    #[test]
    fn maxsim_identical_query_doc_l2() {
        let tokens = basis4(3);
        let score = maxsim(&tokens, &tokens, DistanceMetric::L2);
        assert!((score - 0.0).abs() < 1e-6, "score={score}");
    }

    /// Scoring three tokens against themselves under cosine is expected to give 3.
    #[test]
    fn maxsim_identical_query_doc_cosine() {
        let tokens = basis4(3);
        let score = maxsim(&tokens, &tokens, DistanceMetric::Cosine);
        assert!((score - 3.0).abs() < 1e-5, "score={score}");
    }

    /// An empty query short-circuits to a score of exactly 0.
    #[test]
    fn maxsim_empty_query_returns_zero() {
        let doc_tokens = basis4(1);
        assert_eq!(maxsim(&[], &doc_tokens, DistanceMetric::Cosine), 0.0);
    }

    /// An empty document short-circuits to a score of exactly 0.
    #[test]
    fn maxsim_empty_doc_returns_zero() {
        let query_tokens = basis4(1);
        assert_eq!(maxsim(&query_tokens, &[], DistanceMetric::Cosine), 0.0);
    }

    /// Truncating the query is expected to matter only when a dropped query
    /// token has a matching document token.
    #[test]
    fn budgeted_vs_full_differ_when_extra_tokens_are_relevant() {
        let query_tokens = basis4(4);

        // The document covers axes 0 and 1 only, the same axes the budget keeps.
        let narrow_doc = basis4(2);
        let full = maxsim(&query_tokens, &narrow_doc, DistanceMetric::Cosine);
        let budgeted = budgeted_maxsim(&query_tokens, &narrow_doc, 2, DistanceMetric::Cosine);

        // The document also covers axis 2, which a budget of 2 drops from the query.
        let wide_doc = basis4(3);
        let full2 = maxsim(&query_tokens, &wide_doc, DistanceMetric::Cosine);
        let budgeted2 = budgeted_maxsim(&query_tokens, &wide_doc, 2, DistanceMetric::Cosine);

        assert!(
            (full - budgeted).abs() < 1e-5,
            "first case should be equal: full={full} budgeted={budgeted}"
        );
        assert!(
            full2 > budgeted2,
            "full2 should exceed budgeted2: full2={full2} budgeted2={budgeted2}"
        );
    }

    /// A budget of zero leaves no query tokens, so the score is exactly 0.
    #[test]
    fn budgeted_zero_budget_returns_zero() {
        let tokens = basis4(1);
        assert_eq!(
            budgeted_maxsim(&tokens, &tokens, 0, DistanceMetric::Cosine),
            0.0
        );
    }

    /// A budget above the query length is expected to match the full score.
    #[test]
    fn budgeted_exceeds_query_length_equals_full() {
        let tokens = basis4(2);
        let full = maxsim(&tokens, &tokens, DistanceMetric::Cosine);
        let budgeted = budgeted_maxsim(&tokens, &tokens, u8::MAX, DistanceMetric::Cosine);
        assert!((full - budgeted).abs() < 1e-6);
    }
}