use std::collections::{HashMap, HashSet};
/// One entry of the list produced by [`merge_results`].
#[derive(Clone, Debug, PartialEq)]
pub struct MergedResult {
    /// Identifier of the entry; each id appears at most once in a merged list.
    pub id: u32,
    /// Distance paired with `id` in the input list the entry was taken from.
    /// `merge_results` returns entries ordered by this value, smallest first.
    pub distance: f32,
    /// `true` when the entry was taken from the `delta` input of `merge_results`,
    /// `false` when it was taken from the `main` input.
    pub from_delta: bool,
}
/// Merges `(id, distance)` pairs from `main` and `delta` into a single list of at most
/// `k` entries, ordered by ascending distance.
///
/// Merge rules:
/// * ids contained in `tombstones` are dropped from both inputs;
/// * when an id occurs in both inputs, the `delta` entry is kept (`from_delta == true`)
///   and the `main` entry is ignored, regardless of which distance is smaller;
/// * when an id is repeated inside `delta`, the last occurrence is kept; when it is
///   repeated inside `main`, the first occurrence is kept.
///
/// Ordering is fully deterministic: entries are sorted by ascending distance, NaN
/// distances sort after every other value, and equal distances are ordered by
/// ascending id. The same ordering decides which entries survive the top-`k` cut.
///
/// Returns an empty vector when `k == 0` or when no entry survives filtering.
pub fn merge_results(
    main: Vec<(u32, f32)>,
    delta: Vec<(u32, f32)>,
    tombstones: &HashSet<u32>,
    k: usize,
) -> Vec<MergedResult> {
    /// Total order used for both selection and the final sort.
    ///
    /// A comparator that merely maps incomparable (NaN) pairs to `Equal` is not a
    /// total order, which the slice sorting/selection routines require; and comparing
    /// by distance alone would leave ties in whatever order the hash map yields them.
    fn by_distance_then_id(a: &MergedResult, b: &MergedResult) -> std::cmp::Ordering {
        use std::cmp::Ordering;

        let by_distance = match (a.distance.is_nan(), b.distance.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            // Both values are non-NaN here, so `partial_cmp` always yields `Some`;
            // it is used (rather than `total_cmp`) so that -0.0 and 0.0 stay equal.
            (false, false) => a
                .distance
                .partial_cmp(&b.distance)
                .unwrap_or(Ordering::Equal),
        };
        // Ids are unique after merging, so this tie-break makes the order strict.
        by_distance.then_with(|| a.id.cmp(&b.id))
    }

    if k == 0 {
        return Vec::new();
    }

    let mut by_id: HashMap<u32, MergedResult> =
        HashMap::with_capacity(main.len() + delta.len());

    // Delta entries go in first; `insert` lets a later duplicate replace an earlier one.
    for (id, distance) in delta {
        if tombstones.contains(&id) {
            continue;
        }
        by_id.insert(
            id,
            MergedResult {
                id,
                distance,
                from_delta: true,
            },
        );
    }

    // Main entries only fill ids that are still vacant, so delta wins collisions.
    for (id, distance) in main {
        if tombstones.contains(&id) {
            continue;
        }
        by_id.entry(id).or_insert(MergedResult {
            id,
            distance,
            from_delta: false,
        });
    }

    let mut merged: Vec<MergedResult> = by_id.into_values().collect();

    // Partition so the k smallest entries occupy the front, then drop the rest;
    // this avoids fully sorting candidates that will be discarded anyway.
    if k < merged.len() {
        merged.select_nth_unstable_by(k, by_distance_then_id);
        merged.truncate(k);
    }
    merged.sort_unstable_by(by_distance_then_id);
    merged
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects the ids of `results` in the order they were returned.
    fn ids_of(results: &[MergedResult]) -> Vec<u32> {
        results.iter().map(|r| r.id).collect()
    }

    /// Id 1 is present in both lists: the delta entry (0.9) must replace the main
    /// entry (1.0), and the merged list must come back sorted by distance.
    #[test]
    fn delta_beats_main_wins_collision() {
        let tombstones = HashSet::new();
        let results = merge_results(
            vec![(1, 1.0), (2, 2.0)],
            vec![(3, 0.5), (1, 0.9)],
            &tombstones,
            3,
        );

        assert_eq!(ids_of(&results), vec![3, 1, 2]);

        let origins: Vec<bool> = results.iter().map(|r| r.from_delta).collect();
        assert_eq!(origins, vec![true, true, false]);

        let collided = &results[1];
        assert!((collided.distance - 0.9).abs() < 1e-6);
    }

    /// With `k = 2` only the two closest entries survive, closest first.
    #[test]
    fn top_two_delta_first() {
        let tombstones = HashSet::new();
        let results = merge_results(
            vec![(1, 1.0f32), (2, 2.0f32)],
            vec![(3, 0.5f32)],
            &tombstones,
            2,
        );

        assert_eq!(ids_of(&results), vec![3, 1]);
    }

    /// A tombstoned id never appears in the output.
    #[test]
    fn tombstone_excludes_from_both() {
        let tombstones: HashSet<u32> = std::iter::once(2u32).collect();
        let results = merge_results(
            vec![(1, 1.0f32), (2, 2.0f32)],
            vec![(3, 0.5f32)],
            &tombstones,
            10,
        );

        assert!(!results.iter().any(|r| r.id == 2));
        assert_eq!(results.len(), 2);
    }

    /// No candidates in, no results out.
    #[test]
    fn empty_inputs_returns_empty() {
        let results = merge_results(Vec::new(), Vec::new(), &HashSet::new(), 10);
        assert!(results.is_empty());
    }

    /// `k == 0` short-circuits to an empty list even when candidates exist.
    #[test]
    fn k_zero_returns_empty() {
        let results = merge_results(vec![(1, 0.5f32)], Vec::new(), &HashSet::new(), 0);
        assert!(results.is_empty());
    }
}