use std::cmp::Ordering;
use std::collections::BinaryHeap;
/// One candidate in a top-k selection.
///
/// Equality and ordering look only at `(dist, seq)`; `idx` is carried along as payload and
/// never takes part in a comparison. A smaller `dist` compares as `Less`, and equal distances
/// are ordered by ascending `seq`.
///
/// Every NaN distance compares as greater than every non-NaN distance, whatever its sign bit.
/// Plain `f32::total_cmp` would place NaNs with the sign bit set *before* negative infinity,
/// which would make such a value look like the smallest distance.
#[derive(Debug, Clone, Copy)]
pub(crate) struct Scored {
    /// Distance score; smaller values order first.
    pub(crate) dist: f32,
    /// Tie-breaker applied when two distances compare equal; lower values order first.
    pub(crate) seq: u64,
    /// Payload index; ignored by `Eq` and `Ord`.
    pub(crate) idx: usize,
}

impl Scored {
    /// Total order over `(dist, seq)` shared by the `PartialEq` and `Ord` impls.
    ///
    /// Non-NaN distances are compared with `f32::total_cmp` (so `-0.0` orders before `0.0`).
    /// All NaNs are treated as one equivalence class that sorts after every other distance;
    /// within that class, and for any other distance tie, `seq` decides.
    fn cmp_key(&self, other: &Self) -> Ordering {
        let by_dist = match (self.dist.is_nan(), other.dist.is_nan()) {
            (false, false) => self.dist.total_cmp(&other.dist),
            // Collapse all NaN bit patterns so sign and payload cannot reorder them.
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
        };
        by_dist.then_with(|| self.seq.cmp(&other.seq))
    }
}

impl PartialEq for Scored {
    /// Two entries are equal when their `(dist, seq)` keys compare as `Equal`.
    fn eq(&self, other: &Self) -> bool {
        self.cmp_key(other) == Ordering::Equal
    }
}

// `cmp_key` is a total order (NaNs included), so equality is reflexive.
impl Eq for Scored {}

impl PartialOrd for Scored {
    /// Always `Some`: defers to the total order in `Ord::cmp`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Scored {
    /// Orders by distance first, then by `seq`; see `cmp_key`.
    fn cmp(&self, other: &Self) -> Ordering {
        self.cmp_key(other)
    }
}
/// Returns the positions of the `k` lowest-ranked entries, best first.
///
/// Position `i` pairs `distances[i]` with `seqs[i]` into a `Scored`; ranking is whatever
/// `Scored`'s `Ord` impl defines. At most `min(k, distances.len())` candidates are held at
/// any time in a max-heap whose top is the worst candidate kept so far. The result has at
/// most that many indices, sorted ascending by that order.
///
/// Returns an empty vector when `k == 0` or `distances` is empty.
///
/// # Panics
///
/// Panics when debug assertions are enabled and the two slices differ in length. With
/// debug assertions off, pairing simply stops at the end of the shorter slice.
pub(crate) fn select_topk_indices(distances: &[f32], seqs: &[u64], k: usize) -> Vec<usize> {
    debug_assert_eq!(distances.len(), seqs.len(), "distances and seqs must align");

    let limit = k.min(distances.len());
    if limit == 0 {
        return Vec::new();
    }

    let mut best: BinaryHeap<Scored> = BinaryHeap::with_capacity(limit);
    let candidates = distances
        .iter()
        .zip(seqs)
        .enumerate()
        .map(|(idx, (&dist, &seq))| Scored { dist, seq, idx });

    for candidate in candidates {
        if best.len() >= limit {
            // Heap is full: only a candidate strictly better than the current worst gets in.
            let displaces_worst = matches!(best.peek(), Some(worst) if candidate < *worst);
            if !displaces_worst {
                continue;
            }
            let _displaced = best.pop();
        }
        best.push(candidate);
    }

    // `into_sorted_vec` yields ascending order, i.e. best candidate first.
    let ranked = best.into_sorted_vec();
    ranked.iter().map(|entry| entry.idx).collect()
}
#[cfg(test)]
mod tests {
    // No test below currently calls `unwrap`; the allowance is kept unchanged.
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// Builds `0, 1, .., n - 1` so that sequence order mirrors index order.
    fn seqs_in_order(n: usize) -> Vec<u64> {
        (0u64..).take(n).collect()
    }

    /// `k == 0` yields nothing even when candidates exist.
    #[test]
    fn select_topk_zero_k_returns_empty() {
        let distances = [1.0_f32, 2.0, 3.0];
        let picked = select_topk_indices(&distances, &seqs_in_order(distances.len()), 0);
        assert!(picked.is_empty());
    }

    /// No candidates yields nothing regardless of `k`.
    #[test]
    fn select_topk_empty_distances_returns_empty() {
        let picked = select_topk_indices(&[], &[], 5);
        assert!(picked.is_empty());
    }

    /// Asking for more than is available returns every index, sorted by distance.
    #[test]
    fn select_topk_k_greater_than_n_returns_all_sorted() {
        let distances = [3.0_f32, 1.0, 2.0];
        let picked = select_topk_indices(&distances, &seqs_in_order(distances.len()), 10);
        assert_eq!(picked, [1, 2, 0]);
    }

    /// The three smallest distances come back in ascending order.
    #[test]
    fn select_topk_returns_best_first() {
        let distances = [5.0_f32, 1.0, 4.0, 2.0, 3.0];
        let picked = select_topk_indices(&distances, &seqs_in_order(distances.len()), 3);
        assert_eq!(picked, [1, 3, 4]);
    }

    /// Among equal distances, the entries with the lower `seq` are kept and listed first.
    #[test]
    fn select_topk_breaks_ties_by_lower_seq() {
        let distances = [1.0_f32, 1.0, 1.0, 0.5];
        let seqs = [0_u64, 1, 2, 3];
        let picked = select_topk_indices(&distances, &seqs, 3);
        assert_eq!(picked, [3, 0, 1]);
    }

    /// With all distances equal, the output follows `seq`, not input position.
    #[test]
    fn select_topk_tiebreaker_is_seq_not_idx() {
        let distances = [1.0_f32; 4];
        let seqs = [1_u64, 3, 0, 2];
        let picked = select_topk_indices(&distances, &seqs, 4);
        assert_eq!(picked, [2, 0, 3, 1]);
    }

    /// `f32::NAN` ranks after the finite distances.
    #[test]
    fn select_topk_handles_nan_via_total_cmp() {
        let distances = [f32::NAN, 1.0, 2.0];
        let picked = select_topk_indices(&distances, &seqs_in_order(distances.len()), 3);
        assert_eq!(picked, [1, 2, 0]);
    }
}