/// Returns the `k` smallest elements of `slice` under `compare`, in sorted order.
///
/// The result equals the first `k` elements of `slice` as if it were fully sorted
/// with `compare`, but it is computed with a partial selection plus a sort of only
/// those `k` elements. To obtain the *largest* values, pass a reversed comparator
/// such as `|a, b| b.cmp(a)`.
///
/// * `slice` – the input elements; it may be reordered in an unspecified way.
/// * `k` – how many elements to return; values above `slice.len()` are clamped.
/// * `compare` – the ordering used both to select and to sort the result.
///
/// Returns an empty `Vec` when `slice` is empty or `k == 0`. Elements that compare
/// equal appear in unspecified relative order.
///
/// # Panics
///
/// The std routines this delegates to (`select_nth_unstable_by`,
/// `sort_unstable_by`) may panic if `compare` does not implement a total order.
pub fn select_top_k<T, F>(slice: &mut [T], k: usize, mut compare: F) -> Vec<T>
where
    F: FnMut(&T, &T) -> std::cmp::Ordering,
    T: Clone,
{
    let len = slice.len();
    if len == 0 || k == 0 {
        return Vec::new();
    }
    let k = k.min(len);
    // Partition so the k smallest elements occupy `slice[..k]`. When every
    // element is requested there is nothing to partition away.
    if k < len {
        slice.select_nth_unstable_by(k - 1, &mut compare);
    }
    let mut top_k = slice[..k].to_vec();
    // The selection already scrambled the order of equal elements, so a stable
    // sort gains nothing; the unstable sort is faster and does not allocate.
    top_k.sort_unstable_by(&mut compare);
    top_k
}
/// Returns the `k` best `(index, value)` pairs of `slice`, in sorted order.
///
/// Pairs are ordered by `compare` applied to the value (`.1`) only. Pairs whose
/// values compare equal are ordered by ascending index (`.0`). This tie-break makes
/// the selected set and its order independent of how the input was arranged.
/// The index is presumably the element's original position; the function itself
/// only uses it to break ties. As with `select_top_k`, "best" means smallest under
/// `compare`. Pass a reversed comparator (e.g. `|a, b| b.cmp(a)`) for the largest
/// values.
///
/// * `slice` – the input pairs; it may be reordered in an unspecified way.
/// * `k` – how many pairs to return; values above `slice.len()` are clamped.
/// * `compare` – the ordering on values used to select and sort the result.
///
/// Returns an empty `Vec` when `slice` is empty or `k == 0`.
///
/// # Panics
///
/// The std routines this delegates to (`select_nth_unstable_by`,
/// `sort_unstable_by`) may panic if `compare` does not implement a total order.
pub fn select_top_k_with_index<T, F>(
    slice: &mut [(usize, T)],
    k: usize,
    mut compare: F,
) -> Vec<(usize, T)>
where
    F: FnMut(&T, &T) -> std::cmp::Ordering,
    T: Clone,
{
    let len = slice.len();
    if len == 0 || k == 0 {
        return Vec::new();
    }
    let k = k.min(len);
    // Compare by value first; fall back to the index so ties resolve deterministically.
    let mut by_value_then_index = |a: &(usize, T), b: &(usize, T)| {
        compare(&a.1, &b.1).then_with(|| a.0.cmp(&b.0))
    };
    // Partition so the k best pairs occupy `slice[..k]`; unnecessary when all are kept.
    if k < len {
        slice.select_nth_unstable_by(k - 1, &mut by_value_then_index);
    }
    let mut top_k = slice[..k].to_vec();
    // The index tie-break already fixes the order of ties, so an unstable sort suffices.
    top_k.sort_unstable_by(&mut by_value_then_index);
    top_k
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::cmp::Ordering;

    /// Descending comparator for scores.
    ///
    /// `total_cmp` is a genuine total order (NaN included), unlike
    /// `partial_cmp(..).unwrap_or(Equal)`, which the std selection/sort routines
    /// are not guaranteed to tolerate.
    fn descending(a: &f32, b: &f32) -> Ordering {
        b.total_cmp(a)
    }

    #[test]
    fn test_select_top_k_basic() {
        let mut scores = vec![0.1_f32, 0.9, 0.3, 0.8, 0.5];
        let top = select_top_k(&mut scores, 3, descending);
        assert_eq!(top, vec![0.9, 0.8, 0.5]);
    }

    #[test]
    fn test_select_top_k_empty() {
        let mut scores: Vec<f32> = vec![];
        assert!(select_top_k(&mut scores, 3, descending).is_empty());
    }

    #[test]
    fn test_select_top_k_zero_k() {
        let mut scores = vec![0.1_f32, 0.9];
        assert!(select_top_k(&mut scores, 0, descending).is_empty());
    }

    #[test]
    fn test_select_top_k_k_greater_than_n() {
        let mut scores = vec![0.1_f32, 0.9];
        let top = select_top_k(&mut scores, 5, descending);
        // k is clamped to the slice length and the whole input comes back sorted.
        assert_eq!(top, vec![0.9, 0.1]);
    }

    #[test]
    fn test_select_top_k_with_index() {
        let mut indexed: Vec<(usize, f32)> = vec![(0, 0.1), (1, 0.9), (2, 0.3), (3, 0.8)];
        let top = select_top_k_with_index(&mut indexed, 2, descending);
        assert_eq!(top, vec![(1, 0.9), (3, 0.8)]);
    }

    #[test]
    fn test_select_top_k_with_index_empty() {
        let mut indexed: Vec<(usize, f32)> = vec![];
        assert!(select_top_k_with_index(&mut indexed, 2, descending).is_empty());
    }
}