use std::hash::Hash;
/// Converts distance-based results (lower = closer) into similarity scores
/// (higher = better), returned best-first.
///
/// Each distance `d` is mapped to `1 / (1 + d)`, which sends `[0, ∞)` onto
/// `(0, 1]`.
///
/// The ranking is decided by the raw distances, not by the derived scores. In
/// `f32`, `1 + d` rounds to `1.0` for very small `d`, so near-exact matches
/// would otherwise all tie at `1.0` and fall back to input order. Equal
/// distances keep their input order (the sort is stable).
///
/// Out-of-domain inputs:
/// - negative distances are clamped to `0.0` before scoring (score `1.0`), so
///   scores never exceed `1.0`, never become negative or infinite, and stay
///   non-increasing down the returned list;
/// - NaN distances are ranked last and keep a NaN score.
pub fn from_distances<I: Clone + Eq + Hash>(results: &[(I, f32)]) -> Vec<(I, f32)> {
    let mut scored: Vec<(I, f32)> = results.to_vec();
    // Ascending by distance, NaN last. `total_cmp` alone would place NaNs by
    // their sign bit (negative NaNs before every number).
    scored.sort_by(|a, b| a.1.is_nan().cmp(&b.1.is_nan()).then_with(|| a.1.total_cmp(&b.1)));
    for (_, value) in &mut scored {
        // `NaN < 0.0` is false, so a NaN distance passes through unchanged.
        let dist = if *value < 0.0 { 0.0 } else { *value };
        *value = 1.0 / (1.0 + dist);
    }
    scored
}
/// Converts distance-based results (lower = closer) into similarity scores
/// (higher = better), translating each id with `map_id`, and returns them
/// best-first.
///
/// `map_id` is called once per input, in input order. Each distance `d` is
/// mapped to `1 / (1 + d)`, which sends `[0, ∞)` onto `(0, 1]`.
///
/// The ranking is decided by the raw distances, not by the derived scores. In
/// `f32`, `1 + d` rounds to `1.0` for very small `d`, so near-exact matches
/// would otherwise all tie at `1.0` and fall back to input order. Equal
/// distances keep their input order (the sort is stable).
///
/// Out-of-domain inputs:
/// - negative distances are clamped to `0.0` before scoring (score `1.0`), so
///   scores never exceed `1.0`, never become negative or infinite, and stay
///   non-increasing down the returned list;
/// - NaN distances are ranked last and keep a NaN score.
pub fn from_distances_mapped<I, O, F>(results: &[(I, f32)], map_id: F) -> Vec<(O, f32)>
where
    I: Clone,
    O: Clone + Eq + Hash,
    F: Fn(&I) -> O,
{
    let mut scored: Vec<(O, f32)> = results
        .iter()
        .map(|(id, dist)| (map_id(id), *dist))
        .collect();
    // Ascending by distance, NaN last. `total_cmp` alone would place NaNs by
    // their sign bit (negative NaNs before every number).
    scored.sort_by(|a, b| a.1.is_nan().cmp(&b.1.is_nan()).then_with(|| a.1.total_cmp(&b.1)));
    for (_, value) in &mut scored {
        // `NaN < 0.0` is false, so a NaN distance passes through unchanged.
        let dist = if *value < 0.0 { 0.0 } else { *value };
        *value = 1.0 / (1.0 + dist);
    }
    scored
}
/// Ranks results that are already similarity scores (higher = better),
/// best-first, without changing the scores.
///
/// Equal scores keep their input order (the sort is stable). NaN scores are
/// ranked last regardless of their sign bit. A plain descending `total_cmp`
/// would put positive NaNs ahead of every real score.
pub fn from_similarities<I: Clone + Eq + Hash>(results: &[(I, f32)]) -> Vec<(I, f32)> {
    let mut scored: Vec<(I, f32)> = results.to_vec();
    // Descending by score, NaN last.
    scored.sort_by(|a, b| a.1.is_nan().cmp(&b.1.is_nan()).then_with(|| b.1.total_cmp(&a.1)));
    scored
}
/// Ranks results that are already similarity scores (higher = better),
/// best-first, translating each id with `map_id` and leaving scores as-is.
///
/// `map_id` is called once per input, in input order. Equal scores keep their
/// input order (the sort is stable). NaN scores are ranked last regardless of
/// their sign bit. A plain descending `total_cmp` would put positive NaNs
/// ahead of every real score.
pub fn from_similarities_mapped<I, O, F>(results: &[(I, f32)], map_id: F) -> Vec<(O, f32)>
where
    I: Clone,
    O: Clone + Eq + Hash,
    F: Fn(&I) -> O,
{
    let mut scored: Vec<(O, f32)> = results.iter().map(|(id, s)| (map_id(id), *s)).collect();
    // Descending by score, NaN last.
    scored.sort_by(|a, b| a.1.is_nan().cmp(&b.1.is_nan()).then_with(|| b.1.total_cmp(&a.1)));
    scored
}
/// Ranks inner-product results best-first.
///
/// The values are passed unchanged to [`from_similarities`], which treats a
/// larger value as better. No conversion is applied.
pub fn from_inner_product<I: Clone + Eq + Hash>(results: &[(I, f32)]) -> Vec<(I, f32)> {
    from_similarities::<I>(results)
}
/// Converts logits into scores in `[0, 1]` using the logistic sigmoid
/// `1 / (1 + e^-x)`, and returns them best-first.
///
/// The ranking is decided by the raw logits, not by the derived scores. In
/// `f32` the sigmoid saturates to exactly `1.0` for logits above about 16.6
/// (and to `0.0` below about -88.7). Without this, the strongest results
/// would tie and fall back to input order.
///
/// Equal logits keep their input order (the sort is stable). NaN logits are
/// ranked last and keep a NaN score.
pub fn from_logits<I: Clone + Eq + Hash>(results: &[(I, f32)]) -> Vec<(I, f32)> {
    let mut scored: Vec<(I, f32)> = results.to_vec();
    // Descending by logit, NaN last. The sigmoid is monotone, so this is the
    // same order as by score wherever the scores are distinguishable.
    scored.sort_by(|a, b| a.1.is_nan().cmp(&b.1.is_nan()).then_with(|| b.1.total_cmp(&a.1)));
    for (_, value) in &mut scored {
        *value = 1.0 / (1.0 + (-*value).exp());
    }
    scored
}
/// Converts logits into scores in `[0, 1]` using the logistic sigmoid
/// `1 / (1 + e^-x)`, translating each id with `map_id`, and returns them
/// best-first.
///
/// `map_id` is called once per input, in input order. The ranking is decided
/// by the raw logits, not by the derived scores. In `f32` the sigmoid
/// saturates to exactly `1.0` for logits above about 16.6 (and to `0.0` below
/// about -88.7). Without this, the strongest results would tie and fall back
/// to input order.
///
/// Equal logits keep their input order (the sort is stable). NaN logits are
/// ranked last and keep a NaN score.
pub fn from_logits_mapped<I, O, F>(results: &[(I, f32)], map_id: F) -> Vec<(O, f32)>
where
    I: Clone,
    O: Clone + Eq + Hash,
    F: Fn(&I) -> O,
{
    let mut scored: Vec<(O, f32)> = results
        .iter()
        .map(|(id, logit)| (map_id(id), *logit))
        .collect();
    // Descending by logit, NaN last. The sigmoid is monotone, so this is the
    // same order as by score wherever the scores are distinguishable.
    scored.sort_by(|a, b| a.1.is_nan().cmp(&b.1.is_nan()).then_with(|| b.1.total_cmp(&a.1)));
    for (_, value) in &mut scored {
        *value = 1.0 / (1.0 + (-*value).exp());
    }
    scored
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Ids of a ranked list, in rank order.
    fn ids<T: Copy>(ranked: &[(T, f32)]) -> Vec<T> {
        ranked.iter().map(|&(id, _)| id).collect()
    }

    /// Absolute-tolerance comparison for score checks.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < 1e-6
    }

    #[test]
    fn distances_to_scores() {
        let input = vec![(1u32, 0.0), (2, 1.0), (3, 9.0)];
        let out = from_distances(&input);
        assert_eq!(ids(&out), vec![1, 2, 3]);
        assert!(close(out[0].1, 1.0));
        assert!(close(out[1].1, 0.5));
        assert!(close(out[2].1, 0.1));
    }

    #[test]
    fn distances_mapped() {
        let labels = ["a", "b", "c", "d"];
        let input = vec![(2u32, 0.5), (0, 0.1)];
        let out = from_distances_mapped(&input, |&idx| labels[idx as usize]);
        assert_eq!(ids(&out), vec!["a", "c"]);
    }

    #[test]
    fn similarities_passthrough() {
        let input = vec![("d1", 0.3), ("d2", 0.9), ("d3", 0.6)];
        let out = from_similarities(&input);
        assert_eq!(ids(&out), vec!["d2", "d3", "d1"]);
    }

    #[test]
    fn logits_conversion() {
        let input = vec![("d1", 0.0), ("d2", 5.0), ("d3", -5.0)];
        let out = from_logits(&input);
        assert_eq!(ids(&out), vec!["d2", "d1", "d3"]);
        assert!(out[0].1 > 0.99);
        assert!(close(out[1].1, 0.5));
        assert!(out[2].1 < 0.01);
    }

    #[test]
    fn empty_inputs() {
        let none: Vec<(u32, f32)> = Vec::new();
        assert!(from_distances(&none).is_empty());
        assert!(from_similarities(&none).is_empty());
        assert!(from_logits(&none).is_empty());
    }

    #[test]
    fn adapter_then_fuse() {
        let lexical = vec![("d1", 12.0), ("d2", 10.0), ("d3", 8.0)];
        let dense_distances = vec![("d2", 0.1), ("d4", 0.3), ("d1", 0.9)];
        let dense = from_distances(&dense_distances);
        let fused = crate::rrf(&lexical, &dense);
        assert!(!fused.is_empty());
        let d2_rank = fused.iter().position(|(id, _)| *id == "d2").unwrap();
        assert!(d2_rank < 2, "d2 should rank high (in both lists)");
    }
}