use std::cmp::Ordering;
use std::collections::BinaryHeap;
/// Scales `v` in place so that it has unit Euclidean length.
///
/// The squared norm is accumulated in `f64`. Squaring an `f32` component larger
/// than roughly `1.8e19` overflows to infinity, which would otherwise make a
/// perfectly finite vector look "non-finite" and leave it un-normalized.
///
/// `v` is left untouched when:
/// * its norm is below `1e-12` (this covers zero, near-zero and empty input), or
/// * its norm is not finite (some component is NaN or infinite).
pub fn normalize(v: &mut [f32]) {
    // Same cutoff value an `f32` comparison against `1e-12` would use.
    let min_norm = f64::from(1e-12_f32);

    let norm_sq: f64 = v
        .iter()
        .map(|&x| {
            let wide = f64::from(x);
            wide * wide
        })
        .sum();
    let norm = norm_sq.sqrt();
    if norm < min_norm || !norm.is_finite() {
        return;
    }

    let inv = 1.0 / norm;
    for x in v.iter_mut() {
        // Narrowing is safe: |x| <= norm, so the scaled value lies in [-1, 1].
        *x = (f64::from(*x) * inv) as f32;
    }
}
/// Returns the inner product of `a` and `b`.
///
/// The slices are expected to be the same length. That expectation is only
/// enforced in builds with debug assertions enabled; otherwise elements are
/// paired up until the shorter slice runs out and the remainder is ignored.
pub fn dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "dot: slice lengths must be equal");
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
/// Wrapper that gives `f32` a total order so it can be used where `Ord` is required.
///
/// NaN compares equal to NaN and below every other value. All other values
/// follow the usual numeric order, so `-0.0` and `0.0` are equal.
#[derive(Clone, Copy)]
struct OrdF32(f32);

impl PartialEq for OrdF32 {
    /// Equality is defined through [`Ord::cmp`] so that `==` always agrees with
    /// the ordering. A derived field-wise comparison would report
    /// `NaN != NaN`, which violates the reflexivity that `Eq` promises.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for OrdF32 {}

impl PartialOrd for OrdF32 {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrdF32 {
    fn cmp(&self, other: &Self) -> Ordering {
        match (self.0.is_nan(), other.0.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Less,
            (false, true) => Ordering::Greater,
            // Neither side is NaN here, so `partial_cmp` yields `Some`; the
            // fallback exists only to avoid a panic path.
            (false, false) => self.0.partial_cmp(&other.0).unwrap_or(Ordering::Equal),
        }
    }
}
/// A scored item as stored inside the [`TopK`] heap.
///
/// Only `score` takes part in comparisons; `item` is ignored. The ordering is
/// reversed relative to the score, so a max-heap of entries keeps the
/// lowest-scoring entry at the top.
struct Entry<T> {
    score: OrdF32,
    item: T,
}

impl<T> PartialEq for Entry<T> {
    /// Defined through [`Ord::cmp`] so that `==` can never disagree with the
    /// ordering (in particular for NaN scores, which `cmp` treats as equal).
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl<T> Eq for Entry<T> {}

impl<T> PartialOrd for Entry<T> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<T> Ord for Entry<T> {
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands are swapped on purpose: a lower score makes a "greater" entry.
        other.score.cmp(&self.score)
    }
}
/// Bounded collector that retains the `k` highest-scoring items offered to it.
///
/// Internally the retained entries live in a heap whose top is always the
/// weakest retained entry, so deciding whether a new offer belongs is cheap.
///
/// NaN scores rank below every other score: a NaN offer is kept only while
/// fewer than `k` entries are held, and it is the first to be displaced.
pub struct TopK<T> {
    k: usize,
    heap: BinaryHeap<Entry<T>>,
}

impl<T> TopK<T> {
    /// Upper bound on the number of slots reserved up front by [`TopK::new`].
    const MAX_PREALLOC: usize = 1024;

    /// Creates a collector that keeps at most `k` items.
    ///
    /// The up-front reservation is capped, so a very large `k` (for example
    /// `usize::MAX` to mean "keep everything") does not request an enormous
    /// allocation before a single item has been offered. The heap still grows
    /// on demand up to `k` entries.
    pub fn new(k: usize) -> Self {
        TopK {
            k,
            heap: BinaryHeap::with_capacity(k.min(Self::MAX_PREALLOC)),
        }
    }

    /// Offers `item` with the given `score`.
    ///
    /// While fewer than `k` entries are held the offer is always accepted.
    /// Once full, the offer replaces the weakest retained entry only if its
    /// score is strictly greater; ties leave the existing entry in place.
    /// With `k == 0` every offer is dropped.
    pub fn offer(&mut self, score: f32, item: T) {
        if self.k == 0 {
            return;
        }
        let score = OrdF32(score);
        if self.heap.len() >= self.k {
            // Full: the heap top is the weakest retained entry.
            let beats_weakest = match self.heap.peek() {
                Some(weakest) => score > weakest.score,
                None => false,
            };
            if !beats_weakest {
                return;
            }
            self.heap.pop();
        }
        self.heap.push(Entry { score, item });
    }

    /// Consumes the collector and returns the retained `(score, item)` pairs
    /// ordered from highest to lowest score, with NaN scores last.
    ///
    /// The sort is stable, so entries with equal scores keep the order in
    /// which the heap yields them; that order should be treated as unspecified.
    pub fn into_sorted_desc(self) -> Vec<(f32, T)> {
        let mut ranked: Vec<(f32, T)> = self
            .heap
            .into_iter()
            .map(|entry| (entry.score.0, entry.item))
            .collect();
        ranked.sort_by_key(|&(score, _)| std::cmp::Reverse(OrdF32(score)));
        ranked
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean length of `v`, used to check the result of `normalize`.
    fn length(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    /// Feeds every `(score, item)` pair, in order, into a fresh `TopK` of
    /// size `k` and returns its ranked output.
    fn rank<T>(k: usize, offers: impl IntoIterator<Item = (f32, T)>) -> Vec<(f32, T)> {
        let mut tk = TopK::new(k);
        for (score, item) in offers {
            tk.offer(score, item);
        }
        tk.into_sorted_desc()
    }

    // ---- normalize ----

    #[test]
    fn normalize_simple() {
        let mut v = [3.0f32, 0.0, 0.0];
        normalize(&mut v);
        let [x, y, z] = v;
        assert!((x - 1.0).abs() < 1e-7, "x component should be 1");
        assert!(y.abs() < 1e-7);
        assert!(z.abs() < 1e-7);
    }

    #[test]
    fn normalize_zero_vector_unchanged() {
        let mut v = [0.0f32; 3];
        normalize(&mut v);
        assert_eq!(v, [0.0, 0.0, 0.0]);
    }

    #[test]
    fn normalize_near_zero_vector_unchanged() {
        let tiny = 1e-14f32;
        let mut v = [tiny, 0.0, 0.0];
        normalize(&mut v);
        assert!((v[0] - tiny).abs() < 1e-20);
    }

    #[test]
    fn normalize_produces_unit_vector() {
        let mut v = [1.0f32, 2.0, 3.0];
        normalize(&mut v);
        let len = length(&v);
        assert!(
            (len - 1.0).abs() < 1e-6,
            "normalized vector should have unit length, got {len}"
        );
    }

    #[test]
    fn normalize_negative_components() {
        let mut v = [-4.0f32, 3.0];
        normalize(&mut v);
        let len = length(&v);
        assert!((len - 1.0).abs() < 1e-6);
        assert!(v[0] < 0.0, "sign should be preserved");
    }

    #[test]
    fn normalize_already_unit() {
        let mut v = [1.0f32, 0.0, 0.0];
        normalize(&mut v);
        let [x, y, z] = v;
        assert!((x - 1.0).abs() < 1e-7);
        assert!(y.abs() < 1e-7);
        assert!(z.abs() < 1e-7);
    }

    #[test]
    fn normalize_single_element() {
        let mut v = [5.0f32];
        normalize(&mut v);
        assert!((v[0] - 1.0).abs() < 1e-7);
    }

    #[test]
    fn normalize_empty_slice() {
        // Only checks that an empty slice is accepted without panicking.
        let mut v: [f32; 0] = [];
        normalize(&mut v);
    }

    // ---- dot ----

    #[test]
    fn dot_orthogonal_basis_vectors() {
        let x = [1.0f32, 0.0, 0.0];
        let y = [0.0f32, 1.0, 0.0];
        let z = [0.0f32, 0.0, 1.0];
        let xy = dot(&x, &y);
        let xz = dot(&x, &z);
        let yz = dot(&y, &z);
        assert!(xy.abs() < 1e-7, "orthogonal vectors should have dot=0");
        assert!(xz.abs() < 1e-7);
        assert!(yz.abs() < 1e-7);
    }

    #[test]
    fn dot_unit_vectors_equal_approx_one() {
        let v = [1.0f32, 0.0, 0.0];
        let self_dot = dot(&v, &v);
        assert!((self_dot - 1.0).abs() < 1e-7);
    }

    #[test]
    fn dot_equal_unit_vectors() {
        let mut a = [1.0f32, 1.0, 0.0];
        normalize(&mut a);
        let b = a;
        let product = dot(&a, &b);
        assert!((product - 1.0).abs() < 1e-6);
    }

    #[test]
    fn dot_antiparallel() {
        let a = [1.0f32, 0.0];
        let b = [-1.0f32, 0.0];
        let product = dot(&a, &b);
        assert!(
            (product + 1.0).abs() < 1e-7,
            "antiparallel unit vectors → -1"
        );
    }

    #[test]
    fn dot_known_value() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 5.0, 6.0];
        let product = dot(&a, &b);
        assert!((product - 32.0).abs() < 1e-5);
    }

    #[test]
    fn dot_with_zero_vector() {
        let a = [1.0f32, 2.0, 3.0];
        let z = [0.0f32; 3];
        let product = dot(&a, &z);
        assert!(product.abs() < 1e-7);
    }

    #[test]
    fn dot_empty_slices() {
        let a: [f32; 0] = [];
        let b: [f32; 0] = [];
        assert_eq!(dot(&a, &b), 0.0);
    }

    // ---- TopK ----

    #[test]
    fn topk_k_zero_keeps_nothing() {
        let kept: Vec<(f32, i32)> = rank(0, vec![(1.0, 42), (2.0, 99)]);
        assert!(kept.is_empty());
    }

    #[test]
    fn topk_fewer_than_k_offers() {
        let result: Vec<(f32, &str)> = rank(5, vec![(0.9, "a"), (0.5, "b")]);
        assert_eq!(result.len(), 2);
        let (first, second) = (&result[0], &result[1]);
        assert!((first.0 - 0.9).abs() < 1e-7);
        assert_eq!(first.1, "a");
        assert!((second.0 - 0.5).abs() < 1e-7);
        assert_eq!(second.1, "b");
    }

    #[test]
    fn topk_keeps_top_k_of_many() {
        let offers = vec![(0.5, 1), (0.9, 2), (0.3, 3), (0.7, 4), (0.8, 5), (0.1, 6)];
        let result: Vec<(f32, usize)> = rank(3, offers);
        assert_eq!(result.len(), 3);
        let scores: Vec<f32> = result.iter().map(|pair| pair.0).collect();
        assert!(
            (scores[0] - 0.9).abs() < 1e-7,
            "first should be 0.9, got {}",
            scores[0]
        );
        assert!(
            (scores[1] - 0.8).abs() < 1e-7,
            "second should be 0.8, got {}",
            scores[1]
        );
        assert!(
            (scores[2] - 0.7).abs() < 1e-7,
            "third should be 0.7, got {}",
            scores[2]
        );
    }

    #[test]
    fn topk_into_sorted_desc_order() {
        let result: Vec<(f32, u8)> = rank(4, vec![(0.2, 10), (0.8, 20), (0.5, 30), (0.1, 40)]);
        // Compare each entry with its successor.
        for (higher, lower) in result.iter().zip(result.iter().skip(1)) {
            assert!(
                higher.0 >= lower.0,
                "scores should be non-increasing: {} >= {}",
                higher.0,
                lower.0
            );
        }
    }

    #[test]
    fn topk_exact_k_offers_all_kept() {
        let result: Vec<(f32, i32)> = rank(3, vec![(0.1, 1), (0.2, 2), (0.3, 3)]);
        assert_eq!(result.len(), 3);
    }

    #[test]
    fn topk_nan_score_discarded_in_favor_of_real() {
        let result: Vec<(f32, i32)> = rank(2, vec![(0.8, 1), (0.6, 2), (f32::NAN, 99)]);
        assert_eq!(result.len(), 2, "NaN should not have been added");
        let items: Vec<i32> = result.into_iter().map(|(_, item)| item).collect();
        assert!(items.contains(&1));
        assert!(items.contains(&2));
        assert!(!items.contains(&99), "NaN item should not be in results");
    }

    #[test]
    fn topk_nan_when_heap_not_full_then_evicted_by_real() {
        let result: Vec<(f32, i32)> = rank(2, vec![(f32::NAN, 99), (0.5, 1), (0.9, 2)]);
        assert_eq!(result.len(), 2);
        let items: Vec<i32> = result.into_iter().map(|(_, item)| item).collect();
        assert!(items.contains(&1));
        assert!(items.contains(&2));
        assert!(!items.contains(&99), "NaN item should have been evicted");
    }

    #[test]
    fn topk_negative_scores_ranked_correctly() {
        let result: Vec<(f32, &str)> = rank(2, vec![(-0.5, "bad"), (-0.1, "ok"), (-0.9, "worse")]);
        assert_eq!(result.len(), 2);
        let (best, runner_up) = (&result[0], &result[1]);
        assert!((best.0 - (-0.1)).abs() < 1e-7);
        assert_eq!(best.1, "ok");
        assert!((runner_up.0 - (-0.5)).abs() < 1e-7);
        assert_eq!(runner_up.1, "bad");
    }

    #[test]
    fn topk_k_one_keeps_best() {
        let result: Vec<(f32, &str)> = rank(1, vec![(0.3, "a"), (0.7, "b"), (0.5, "c")]);
        assert_eq!(result.len(), 1);
        let (score, item) = result[0];
        assert_eq!(item, "b");
        assert!((score - 0.7).abs() < 1e-7);
    }

    #[test]
    fn topk_cosine_workflow() {
        let mut query = [1.0f32, 1.0, 0.0];
        normalize(&mut query);
        let stored = [
            [1.0f32, 0.0, 0.0],
            [0.0f32, 1.0, 0.0],
            [0.0f32, 0.0, 1.0],
        ];
        // Score every stored row against the query and rank by row index.
        let scored = stored
            .iter()
            .enumerate()
            .map(|(i, row)| (dot(row, &query), i));
        let result: Vec<(f32, usize)> = rank(2, scored);
        assert_eq!(result.len(), 2);
        let rows: Vec<usize> = result.into_iter().map(|(_, row)| row).collect();
        assert!(rows.contains(&0), "row 0 should be in top-2");
        assert!(rows.contains(&1), "row 1 should be in top-2");
        assert!(
            !rows.contains(&2),
            "row 2 (orthogonal) should not be in top-2"
        );
    }

    #[test]
    fn topk_ties_all_kept_within_k() {
        let result: Vec<(f32, i32)> = rank(3, (0..5).map(|i| (0.5, i)));
        assert_eq!(result.len(), 3);
        for &(score, _) in &result {
            assert!((score - 0.5).abs() < 1e-7);
        }
    }
}