/// Bounded collector that retains the `k` smallest distances offered to it,
/// together with the `u32` id each distance was inserted with.
///
/// Storage invariant: `distances` is sorted in descending order (largest
/// first) and `ids[i]` belongs to `distances[i]`. Keeping the worst retained
/// distance at index 0 makes [`TopK::threshold`] an O(1) read; each
/// insertion is a binary search plus an O(k) element shift.
///
/// Ties are resolved by insertion order: a candidate that merely equals the
/// current threshold is rejected, and [`TopK::into_sorted`] reports entries
/// with equal distances oldest-first.
#[derive(Debug, Clone)]
pub struct TopK {
    /// Maximum number of retained entries (validated to be >= 1 in `new`).
    k: usize,
    /// Retained distances, sorted descending.
    distances: Vec<f32>,
    /// Ids parallel to `distances`.
    ids: Vec<u32>,
    /// Number of retained entries; always equals `distances.len()` and is <= `k`.
    count: usize,
}

impl TopK {
    /// Creates an empty collector that keeps at most `k` entries.
    ///
    /// # Panics
    ///
    /// Panics if `k` is 0.
    #[must_use]
    pub fn new(k: usize) -> Self {
        assert!(k > 0, "innr::TopK: k must be >= 1");
        Self {
            k,
            distances: Vec::with_capacity(k),
            ids: Vec::with_capacity(k),
            count: 0,
        }
    }

    /// Returns the current admission threshold.
    ///
    /// While fewer than `k` entries are held this is `f32::INFINITY` (every
    /// non-NaN candidate is admitted). Once full it is the largest retained
    /// distance, and a new candidate is admitted only if it is strictly
    /// smaller.
    #[inline]
    #[must_use]
    pub fn threshold(&self) -> f32 {
        if self.count < self.k {
            f32::INFINITY
        } else {
            self.distances[0]
        }
    }

    /// Offers the candidate `(id, distance)` to the collector.
    ///
    /// While fewer than `k` entries are held the candidate is always kept.
    /// Once full, it replaces the current worst entry only if `distance` is
    /// strictly below [`threshold`](Self::threshold).
    ///
    /// NaN distances are ignored. NaN is unordered and has no place in the
    /// sorted storage. Admitting one could put NaN at index 0, after which
    /// every `distance < threshold` test is false and the result set stops
    /// updating.
    #[inline]
    pub fn insert(&mut self, id: u32, distance: f32) {
        if distance.is_nan() {
            return;
        }
        if self.count < self.k {
            self.insert_sorted(id, distance);
            self.count += 1;
        } else if distance < self.distances[0] {
            // `distance` is strictly below distances[0], so its descending
            // insertion point `pos` is >= 1. Evict index 0 by sliding
            // [1, pos) one slot left, then write the candidate into the slot
            // freed at `pos - 1`; entries from `pos` onward are untouched.
            let pos = self.find_insert_pos(distance, self.k);
            debug_assert!(pos >= 1);
            self.distances.copy_within(1..pos, 0);
            self.ids.copy_within(1..pos, 0);
            self.distances[pos - 1] = distance;
            self.ids[pos - 1] = id;
        }
    }

    /// Number of entries currently retained (never more than `k`).
    #[inline]
    #[must_use]
    pub fn len(&self) -> usize {
        self.count
    }

    /// Returns `true` if no entry has been retained yet.
    #[inline]
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }

    /// Consumes the collector, returning `(id, distance)` pairs sorted by
    /// ascending distance (closest first).
    #[must_use]
    pub fn into_sorted(self) -> Vec<(u32, f32)> {
        // Storage is descending, so walking it backwards yields ascending order.
        self.ids.into_iter().zip(self.distances).rev().collect()
    }

    /// Inserts into the not-yet-full storage at the position that keeps it
    /// sorted descending. Caller guarantees `count < k` and `distance` is not NaN.
    #[inline]
    fn insert_sorted(&mut self, id: u32, distance: f32) {
        let pos = self.find_insert_pos(distance, self.count);
        // Capacity `k` was reserved in `new` and `count < k`, so no reallocation.
        self.distances.insert(pos, distance);
        self.ids.insert(pos, id);
    }

    /// Returns where `distance` belongs in the descending prefix
    /// `distances[..len]`: the count of entries strictly greater than it.
    ///
    /// This places a new value before any equal values in descending order,
    /// which makes ties come out oldest-first from `into_sorted`.
    #[inline]
    fn find_insert_pos(&self, distance: f32, len: usize) -> usize {
        self.distances[..len].partition_point(|&d| d > distance)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a collector of capacity `k` and feeds it every `(id, distance)`
    /// pair in order.
    fn filled(k: usize, items: &[(u32, f32)]) -> TopK {
        let mut top = TopK::new(k);
        for &(id, dist) in items {
            top.insert(id, dist);
        }
        top
    }

    #[test]
    fn basic_top3() {
        let top = filled(3, &[(0, 1.5), (1, 0.3), (2, 2.0), (3, 0.8), (4, 5.0)]);
        assert_eq!(top.len(), 3);
        let results = top.into_sorted();
        assert_eq!(results, vec![(1, 0.3), (3, 0.8), (0, 1.5)]);
    }

    #[test]
    fn threshold_tracking() {
        let mut top = TopK::new(3);
        assert_eq!(top.threshold(), f32::INFINITY);
        // Each row: (id, distance, expected threshold right after the insert).
        let steps: [(u32, f32, f32); 6] = [
            (0, 1.0, f32::INFINITY),
            (1, 2.0, f32::INFINITY),
            (2, 3.0, 3.0),
            (3, 1.5, 2.0),
            (4, 0.5, 1.5),
            (5, 10.0, 1.5),
        ];
        for &(id, dist, expected) in steps.iter() {
            top.insert(id, dist);
            assert_eq!(top.threshold(), expected, "after inserting id {id}");
        }
    }

    #[test]
    fn duplicate_distances() {
        let top = filled(3, &[(0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)]);
        assert_eq!(top.len(), 3);
        let results = top.into_sorted();
        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|&(_, d)| d == 1.0));
    }

    #[test]
    fn k1_edge_case() {
        let mut top = TopK::new(1);
        assert_eq!(top.threshold(), f32::INFINITY);
        let steps: [(u32, f32, f32); 4] = [(0, 5.0, 5.0), (1, 3.0, 3.0), (2, 10.0, 3.0), (3, 1.0, 1.0)];
        for &(id, dist, expected) in steps.iter() {
            top.insert(id, dist);
            assert_eq!(top.threshold(), expected);
        }
        assert_eq!(top.into_sorted(), vec![(3, 1.0)]);
    }

    #[test]
    fn large_n_k10() {
        const K: usize = 10;
        let mut top = TopK::new(K);
        (0u32..10_000).for_each(|i| top.insert(i, i as f32));
        assert_eq!(top.len(), K);
        let results = top.into_sorted();
        assert_eq!(results.len(), K);
        for (rank, &(id, dist)) in (0u32..).zip(results.iter()) {
            assert_eq!(id, rank);
            assert!((dist - rank as f32).abs() < 1e-6);
        }
    }

    #[test]
    fn sorted_output_ascending() {
        let mut top = TopK::new(5);
        (0u32..5).rev().for_each(|i| top.insert(i, i as f32));
        let results = top.into_sorted();
        for (i, pair) in results.windows(2).enumerate() {
            assert!(pair[0].1 <= pair[1].1, "not ascending at index {i}");
        }
    }

    #[test]
    fn is_empty_and_len() {
        let mut top = TopK::new(4);
        assert!(top.is_empty());
        assert_eq!(top.len(), 0);

        top.insert(0, 1.0);
        assert!(!top.is_empty());
        assert_eq!(top.len(), 1);

        for id in 1u32..=3 {
            top.insert(id, (id + 1) as f32);
        }
        assert_eq!(top.len(), 4);

        // Capacity is reached; a worse candidate must not grow the collection.
        top.insert(4, 5.0);
        assert_eq!(top.len(), 4);
    }

    #[test]
    fn insert_in_sorted_order() {
        let top = filled(4, &[(0, 1.0), (1, 2.0), (2, 3.0), (3, 4.0), (4, 0.5)]);
        let results = top.into_sorted();
        assert_eq!(results.first(), Some(&(4, 0.5)));
        assert_eq!(results[3], (2, 3.0));
    }
}