use alloc::vec::Vec;
use core::cmp::Ordering;
use core::mem::MaybeUninit;
/// Ordering used by [`LoserTree`] to decide who wins each match.
///
/// `compare(a, b) == Ordering::Less` means `a` wins against `b` (it is
/// handed out by the tree before `b`). Any `Fn(&E, &E) -> Ordering` closure
/// or function already implements this trait through the blanket impl, so
/// most callers never name it directly.
#[diagnostic::on_unimplemented(
    message = "`{Self}` is not a comparator for `LoserTree<{E}, _>`",
    label = "missing `EntryComparator<{E}>` impl",
    note = "implement `EntryComparator<{E}>` directly, or pass a closure of \
            type `Fn(&{E}, &{E}) -> core::cmp::Ordering` — a blanket impl \
            forwards every such closure to this trait automatically"
)]
pub trait EntryComparator<E> {
    /// Compares two entries; `Less` means `a` ranks ahead of `b`.
    fn compare(&self, a: &E, b: &E) -> Ordering;
}
/// Lets any plain comparison closure or function serve as an
/// [`EntryComparator`] with no wrapper type.
#[diagnostic::do_not_recommend]
impl<E, F> EntryComparator<E> for F
where
    F: Fn(&E, &E) -> Ordering,
{
    #[expect(
        clippy::inline_always,
        reason = "blanket forwarder must inline or the indirection this trait \
                  eliminates comes back; verified flat in disassembly"
    )]
    #[inline(always)]
    fn compare(&self, lhs: &E, rhs: &E) -> Ordering {
        // `&F` is itself callable whenever `F: Fn`, so forward through the reference.
        let forward: &F = self;
        forward(lhs, rhs)
    }
}
/// A tournament tree of losers over a fixed set of source slots.
///
/// The leaf count is the number of sources rounded up to a power of two
/// (at least 2). Internal nodes use heap numbering (root `1`, children of
/// `k` are `2k` and `2k + 1`). `tree[0]` holds the slot index of the
/// overall winner and `tree[k]` holds the slot index that lost the match
/// at node `k`.
pub struct LoserTree<E, F> {
    /// One cell per leaf; initialized exactly when the matching `present`
    /// byte is non-zero.
    leaves: Vec<MaybeUninit<E>>,
    /// Occupancy flags parallel to `leaves` (`0` = empty slot).
    present: Vec<u8>,
    /// `tree[0]`: winning slot; `tree[k]` for `k >= 1`: loser at node `k`.
    tree: Vec<usize>,
    /// Length of the `initial` vector given to `build`, before padding.
    n_sources: usize,
    /// Number of slots currently holding an entry.
    active: usize,
    /// Comparator deciding each match.
    cmp: F,
}
/// Summarizes the tree's bookkeeping (source count, live entries, leaf
/// capacity). Entries themselves are never printed, so `E` is not required
/// to implement `Debug`.
impl<E, F> core::fmt::Debug for LoserTree<E, F> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("LoserTree")
            .field("n_sources", &self.n_sources)
            .field("active", &self.active)
            .field("cap", &self.leaves.len())
            .finish_non_exhaustive()
    }
}
impl<E, F> Drop for LoserTree<E, F> {
    /// Drops every entry still held by an occupied slot. Empty slots are
    /// uninitialized and are skipped.
    fn drop(&mut self) {
        let occupied = self
            .leaves
            .iter_mut()
            .zip(&self.present)
            .filter(|(_, flag)| **flag != 0);
        for (leaf, _) in occupied {
            // SAFETY: a non-zero `present` flag means this cell holds an
            // initialized value, and this is its last use before the Vec frees it.
            unsafe { leaf.assume_init_drop() };
        }
    }
}
impl<E, F> LoserTree<E, F>
where
    F: EntryComparator<E>,
{
    /// Builds a tree with one slot per element of `initial`.
    ///
    /// Slot `i` corresponds to `initial[i]`. `Some` entries occupy their
    /// slot. `None` entries, and the padding slots that round the leaf count
    /// up to a power of two (at least 2), start empty. In this initial
    /// tournament, entries comparing `Equal` are won by the lower slot index.
    pub fn build(initial: Vec<Option<E>>, cmp: F) -> Self {
        let n = initial.len();
        let cap = n.next_power_of_two().max(2);
        let mut leaves: Vec<MaybeUninit<E>> = Vec::with_capacity(cap);
        let mut present: Vec<u8> = Vec::with_capacity(cap);
        let mut active = 0_usize;
        for item in initial {
            if let Some(v) = item {
                leaves.push(MaybeUninit::new(v));
                present.push(1);
                active += 1;
            } else {
                leaves.push(MaybeUninit::uninit());
                present.push(0);
            }
        }
        // Pad with empty slots so the implicit binary tree is complete.
        while leaves.len() < cap {
            leaves.push(MaybeUninit::uninit());
            present.push(0);
        }
        // tree[0] = overall winner, tree[1..cap] = loser of each internal match.
        let tree = alloc::vec![0; cap];
        let mut t = Self {
            leaves,
            present,
            tree,
            n_sources: n,
            active,
            cmp,
        };
        t.build_subtree(1, 0, cap);
        t
    }

    /// Returns the number of sources passed to [`build`](Self::build),
    /// excluding padding slots.
    #[inline]
    #[expect(
        dead_code,
        reason = "part of the slot-set protocol, used by future callers"
    )]
    pub fn slots(&self) -> usize {
        self.n_sources
    }

    /// Returns `true` when no slot holds an entry.
    #[expect(
        clippy::inline_always,
        reason = "called from winner_slot() on every merger step; forcing cross-crate inlining \
                  for the bench compilation unit measurably tightens the hot loop"
    )]
    #[inline(always)]
    pub fn is_empty(&self) -> bool {
        self.active == 0
    }

    /// Returns how many slots currently hold an entry.
    #[inline]
    #[cfg_attr(
        not(test),
        expect(
            dead_code,
            reason = "diagnostic accessor used by unit tests and future callers"
        )
    )]
    pub fn active_count(&self) -> usize {
        self.active
    }

    /// Returns the slot index holding the current minimum, or `None` when
    /// the tree is empty.
    #[expect(
        clippy::inline_always,
        reason = "hot per-step routine; cross-crate inlining for benches"
    )]
    #[inline(always)]
    #[expect(
        clippy::indexing_slicing,
        reason = "tree[0] always exists: cap >= 2 by construction"
    )]
    pub fn winner_slot(&self) -> Option<usize> {
        if self.is_empty() {
            None
        } else {
            Some(self.tree[0])
        }
    }

    /// Borrows the current minimum without removing it.
    #[inline]
    #[expect(
        clippy::indexing_slicing,
        reason = "idx from winner_slot() is always < leaves.len()"
    )]
    #[cfg_attr(
        not(test),
        expect(
            dead_code,
            reason = "exposed for callers that want to inspect the winner without popping; \
                      seeking_merger reads the winner via winner_slot() since the slot index \
                      equals the source index by construction"
        )
    )]
    pub fn peek_min(&self) -> Option<&E> {
        let idx = self.winner_slot()?;
        if self.present[idx] == 0 {
            return None;
        }
        // SAFETY: present[idx] != 0 was just checked, so the cell is initialized.
        Some(unsafe { self.leaves[idx].assume_init_ref() })
    }

    /// Replaces the current minimum with `new` (in the same slot) and
    /// returns the old minimum.
    ///
    /// # Panics
    ///
    /// Panics if the tree is empty.
    #[expect(
        clippy::expect_used,
        reason = "empty-tree panic is the documented contract"
    )]
    #[expect(
        clippy::indexing_slicing,
        reason = "slot from winner_slot() is always < leaves.len()"
    )]
    pub fn replace_min(&mut self, new: E) -> E {
        let slot = self
            .winner_slot()
            .expect("replace_min called on empty LoserTree");
        debug_assert!(
            self.present[slot] != 0,
            "LoserTree winner slot must be present when winner_slot() returns Some"
        );
        // SAFETY: while active > 0 the overall winner is an occupied slot
        // (present beats empty in every match), so the cell is initialized.
        let old = unsafe {
            core::mem::replace(&mut self.leaves[slot], MaybeUninit::new(new)).assume_init()
        };
        self.replay(slot);
        old
    }

    /// Removes and returns the current minimum, leaving its slot empty.
    #[expect(
        clippy::indexing_slicing,
        reason = "slot from winner_slot() is always < leaves.len()"
    )]
    pub fn pop_min(&mut self) -> Option<E> {
        let slot = self.winner_slot()?;
        debug_assert!(
            self.present[slot] != 0,
            "LoserTree winner slot must be present when winner_slot() returns Some"
        );
        // SAFETY: as in `replace_min`, the winner is occupied while active > 0;
        // present[slot] is cleared below so the moved-out cell is never reused.
        let old = unsafe {
            core::mem::replace(&mut self.leaves[slot], MaybeUninit::uninit()).assume_init()
        };
        self.present[slot] = 0;
        self.active -= 1;
        self.replay(slot);
        Some(old)
    }

    /// Removes and returns the entry in `slot`, which need not be the current
    /// winner. Returns `None` if `slot` is out of range or already empty.
    #[expect(
        clippy::indexing_slicing,
        reason = "slot is checked < self.leaves.len() before indexing"
    )]
    pub fn take_slot(&mut self, slot: usize) -> Option<E> {
        if slot >= self.leaves.len() || self.present[slot] == 0 {
            return None;
        }
        // SAFETY: present[slot] != 0 was just checked, so the cell is
        // initialized; present[slot] is cleared below so it is never read again.
        let taken = unsafe {
            core::mem::replace(&mut self.leaves[slot], MaybeUninit::uninit()).assume_init()
        };
        self.present[slot] = 0;
        self.active -= 1;
        // `replay` is only valid for the overall winner. An arbitrary slot may
        // already be recorded as a loser on its path; replaying it would
        // overwrite tree[0] and lose the real winner.
        self.replay_after_removal(slot);
        Some(taken)
    }

    /// Plays the initial tournament over leaves `leaf_lo..leaf_hi` rooted at
    /// internal `node`. Stores each match's loser and returns the subtree's
    /// winner; the root (node 1) also records the overall winner in `tree[0]`.
    /// Ties go to the left (lower-index) side.
    #[expect(
        clippy::indexing_slicing,
        reason = "node is always in 1..cap by recursive construction; tree[0] always in bounds"
    )]
    fn build_subtree(&mut self, node: usize, leaf_lo: usize, leaf_hi: usize) -> usize {
        if leaf_hi - leaf_lo == 1 {
            return leaf_lo;
        }
        let mid = usize::midpoint(leaf_lo, leaf_hi);
        let left_winner = self.build_subtree(2 * node, leaf_lo, mid);
        let right_winner = self.build_subtree(2 * node + 1, mid, leaf_hi);
        let (winner, loser) = match self.cmp_indices(left_winner, right_winner) {
            Ordering::Less | Ordering::Equal => (left_winner, right_winner),
            Ordering::Greater => (right_winner, left_winner),
        };
        self.tree[node] = loser;
        if node == 1 {
            self.tree[0] = winner;
        }
        winner
    }

    /// Re-plays the matches on `leaf`'s path after its contents changed.
    ///
    /// Only valid when `leaf` is the current overall winner (`tree[0]`), so
    /// it is not recorded as a loser anywhere. Leaf `i` sits at virtual node
    /// `cap + i`, so its parent is `(cap + i) / 2` (the `midpoint` below). A
    /// stored loser displaces the climbing leaf only when it compares
    /// strictly `Less`, so ties keep the climbing leaf.
    #[expect(
        clippy::inline_always,
        reason = "the single hot per-merger-step routine; cross-crate inlining for benches"
    )]
    #[inline(always)]
    #[expect(
        clippy::indexing_slicing,
        reason = "node traverses 1..cap by construction; tree[0] always exists"
    )]
    fn replay(&mut self, leaf: usize) {
        let mut winner = leaf;
        let mut node = usize::midpoint(self.leaves.len(), leaf);
        while node >= 1 {
            let other = self.tree[node];
            if self.cmp_indices(other, winner) == Ordering::Less {
                self.tree[node] = winner;
                winner = other;
            }
            node >>= 1;
        }
        self.tree[0] = winner;
    }

    /// Repairs the tree after `leaf` was emptied. `leaf` may or may not be
    /// the current overall winner.
    ///
    /// Every slot index appears exactly once across `tree[0..cap]`. If `leaf`
    /// is not the winner, it is stored as the loser at exactly one node on
    /// its path: the node where it last lost.
    ///
    /// - Below that node, matches are recomputed exactly as in
    ///   [`replay`](Self::replay).
    /// - At that node, the new winner of `leaf`'s side is recorded as the
    ///   loser and the walk stops. That new winner ranks no better than the
    ///   old entry did, so the opponent that beat `leaf` still wins, and
    ///   nothing above changes (assumes a consistent total order).
    ///
    /// If `leaf` is never found as a stored loser it was the winner, and the
    /// walk degenerates into an ordinary replay.
    #[expect(
        clippy::indexing_slicing,
        reason = "node traverses 1..cap by construction; tree[0] always exists"
    )]
    fn replay_after_removal(&mut self, leaf: usize) {
        let mut winner = leaf;
        let mut node = usize::midpoint(self.leaves.len(), leaf);
        while node >= 1 {
            let other = self.tree[node];
            if other == leaf {
                // `leaf` last lost here; its side's new winner loses here too.
                self.tree[node] = winner;
                return;
            }
            if self.cmp_indices(other, winner) == Ordering::Less {
                self.tree[node] = winner;
                winner = other;
            }
            node >>= 1;
        }
        self.tree[0] = winner;
    }

    /// Compares the entries in slots `a` and `b`. An occupied slot always
    /// beats an empty one, and two empty slots compare `Equal`.
    #[expect(
        clippy::inline_always,
        reason = "called O(log cap) per replay; cross-crate inlining for benches"
    )]
    #[inline(always)]
    #[expect(
        clippy::indexing_slicing,
        reason = "callers (build_subtree, replay) only pass slot indices < cap; \
                  present.len() == leaves.len() == cap by construction"
    )]
    fn cmp_indices(&self, a: usize, b: usize) -> Ordering {
        match (self.present[a] != 0, self.present[b] != 0) {
            (true, true) => {
                // SAFETY: both present flags are set, so both cells are initialized.
                let va = unsafe { self.leaves[a].assume_init_ref() };
                let vb = unsafe { self.leaves[b].assume_init_ref() };
                self.cmp.compare(va, vb)
            }
            (true, false) => Ordering::Less,
            (false, true) => Ordering::Greater,
            (false, false) => Ordering::Equal,
        }
    }
}
// Unit tests for `LoserTree`; most use `u32` entries ordered by `cmp_u32`.
#[cfg(test)]
#[expect(clippy::unwrap_used, reason = "test assertions")]
#[expect(
    clippy::trivially_copy_pass_by_ref,
    reason = "cmp_u32 signature mirrors the Fn(&E, &E) -> Ordering trait"
)]
#[expect(
    clippy::cast_possible_truncation,
    reason = "test sizes are small (n <= 33)"
)]
mod tests {
    use super::*;
    // Brings `test_log`'s `test` attribute into scope, so each `#[test]` below
    // resolves to it (presumably to set up logging per test — confirm in test_log docs).
    use test_log::test;

    /// Ascending comparator for `u32` entries.
    fn cmp_u32(a: &u32, b: &u32) -> Ordering {
        a.cmp(b)
    }

    /// Pops every remaining entry and returns them in pop order.
    fn collect<F: Fn(&u32, &u32) -> Ordering>(mut t: LoserTree<u32, F>) -> Vec<u32> {
        let mut out = Vec::new();
        while let Some(v) = t.pop_min() {
            out.push(v);
        }
        out
    }

    /// A tree built only from `None` slots reports empty through every accessor.
    #[test]
    fn empty_tree() {
        let t: LoserTree<u32, fn(&u32, &u32) -> Ordering> =
            LoserTree::build(alloc::vec![None, None, None], cmp_u32);
        assert!(t.is_empty());
        assert_eq!(t.active_count(), 0);
        assert_eq!(t.peek_min(), None);
        assert_eq!(t.winner_slot(), None);
    }

    /// One source (padded to two leaves) yields its entry, then the tree is empty.
    #[test]
    fn single_slot() {
        let mut t = LoserTree::build(alloc::vec![Some(42_u32)], cmp_u32);
        assert!(!t.is_empty());
        assert_eq!(t.peek_min(), Some(&42));
        assert_eq!(t.pop_min(), Some(42));
        assert!(t.is_empty());
        assert_eq!(t.pop_min(), None);
    }

    /// Draining a full power-of-two tree yields ascending order.
    #[test]
    fn drain_in_order() {
        let t = LoserTree::build(alloc::vec![Some(3_u32), Some(1), Some(4), Some(2)], cmp_u32);
        assert_eq!(collect(t), [1, 2, 3, 4]);
    }

    /// Five sources are padded to eight leaves; the empty padding never surfaces.
    #[test]
    fn non_pow2_padding() {
        let t = LoserTree::build(
            alloc::vec![Some(50_u32), Some(10), Some(40), Some(20), Some(30)],
            cmp_u32,
        );
        assert_eq!(collect(t), [10, 20, 30, 40, 50]);
    }

    /// Refilling the winning slot with a value that is still smallest keeps
    /// that slot as the winner.
    #[test]
    fn replace_min_stays_winner_when_still_smallest() {
        let mut t = LoserTree::build(
            alloc::vec![Some(1_u32), Some(100), Some(200), Some(300)],
            cmp_u32,
        );
        assert_eq!(t.replace_min(2), 1);
        assert_eq!(t.peek_min(), Some(&2));
        assert_eq!(t.winner_slot(), Some(0));
        assert_eq!(t.replace_min(3), 2);
        assert_eq!(t.peek_min(), Some(&3));
        assert_eq!(t.winner_slot(), Some(0));
    }

    /// Refilling the winner with a larger value hands the win to another slot.
    #[test]
    fn replace_min_changes_winner() {
        let mut t = LoserTree::build(alloc::vec![Some(1_u32), Some(5), Some(3), Some(7)], cmp_u32);
        assert_eq!(t.winner_slot(), Some(0));
        assert_eq!(t.replace_min(4), 1);
        assert_eq!(t.peek_min(), Some(&3)); assert_eq!(t.winner_slot(), Some(2));
        assert_eq!(t.replace_min(6), 3);
        assert_eq!(t.peek_min(), Some(&4)); assert_eq!(t.winner_slot(), Some(0));
    }

    /// Popping one entry updates the count and the next minimum; the rest drain in order.
    #[test]
    fn pop_min_then_drain() {
        let mut t = LoserTree::build(
            alloc::vec![Some(10_u32), Some(20), Some(5), Some(15)],
            cmp_u32,
        );
        assert_eq!(t.pop_min(), Some(5));
        assert_eq!(t.active_count(), 3);
        assert_eq!(t.peek_min(), Some(&10));
        assert_eq!(collect(t), [10, 15, 20]);
    }

    /// Interleaving `replace_min` and `pop_min` keeps global ascending order.
    #[test]
    fn mixed_replace_and_pop() {
        let mut t = LoserTree::build(alloc::vec![Some(1_u32), Some(2), Some(3), Some(4)], cmp_u32);
        assert_eq!(t.replace_min(5), 1); assert_eq!(t.replace_min(6), 2); assert_eq!(t.pop_min(), Some(3));
        assert_eq!(t.pop_min(), Some(4));
        assert_eq!(t.pop_min(), Some(5));
        assert_eq!(t.pop_min(), Some(6));
        assert!(t.is_empty());
    }

    /// A reversed comparator turns the "min" operations into max-first order.
    #[test]
    fn reverse_comparator_gives_max_tree() {
        let cmp = |a: &u32, b: &u32| b.cmp(a);
        let mut t = LoserTree::build(alloc::vec![Some(1_u32), Some(4), Some(2), Some(3)], cmp);
        assert_eq!(t.peek_min(), Some(&4)); assert_eq!(t.pop_min(), Some(4));
        assert_eq!(t.pop_min(), Some(3));
        assert_eq!(t.pop_min(), Some(2));
        assert_eq!(t.pop_min(), Some(1));
    }

    /// Equal keys are ordered deterministically when the comparator itself
    /// breaks ties (here by the second tuple field).
    #[test]
    fn deterministic_tiebreak_by_cmp() {
        let cmp = |a: &(u32, usize), b: &(u32, usize)| (a.0, a.1).cmp(&(b.0, b.1));
        let mut t = LoserTree::build(
            alloc::vec![Some((5_u32, 0)), Some((5, 1)), Some((5, 2)), Some((5, 3)),],
            cmp,
        );
        assert_eq!(t.winner_slot(), Some(0));
        let mut order = Vec::new();
        while let Some((_, idx)) = t.pop_min() {
            order.push(idx);
        }
        assert_eq!(order, [0, 1, 2, 3]);
    }

    /// Simulates a k-way merge. Each slot is fed from its own sorted bucket.
    /// The winner's slot is refilled via `replace_min` while its bucket lasts,
    /// then emptied via `pop_min`. The output must equal the sorted input.
    #[test]
    fn random_inputs_match_sorted_reference() {
        use rand::SeedableRng;
        use rand::seq::SliceRandom;
        // Fixed seed so any failure can be reproduced.
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xC0DE_F00D);
        // Sizes straddle powers of two to exercise padding.
        for n in [1_usize, 2, 3, 7, 8, 9, 31, 32, 33] {
            for trial in 0..32 {
                // 4 * n distinct values, shuffled, then dealt round-robin into n buckets.
                let mut all: Vec<u32> = (0..(n as u32 * 4)).collect();
                all.shuffle(&mut rng);
                let mut buckets: Vec<Vec<u32>> = (0..n).map(|_| Vec::new()).collect();
                for (i, v) in all.iter().enumerate() {
                    #[expect(clippy::indexing_slicing, reason = "i % n always < n")]
                    buckets[i % n].push(*v);
                }
                for b in &mut buckets {
                    b.sort_unstable();
                }
                let mut reference = all.clone();
                reference.sort_unstable();
                let mut iters: Vec<std::vec::IntoIter<u32>> =
                    buckets.into_iter().map(IntoIterator::into_iter).collect();
                // Seed each slot with the head of its bucket.
                let initial: Vec<Option<u32>> = iters.iter_mut().map(Iterator::next).collect();
                let mut t = LoserTree::build(initial, cmp_u32);
                let mut out = Vec::with_capacity(reference.len());
                while let Some(slot) = t.winner_slot() {
                    #[expect(clippy::indexing_slicing, reason = "slot < n by construction")]
                    if let Some(next_val) = iters[slot].next() {
                        out.push(t.replace_min(next_val));
                    } else {
                        out.push(t.pop_min().unwrap());
                    }
                }
                assert_eq!(out, reference, "n={n} trial={trial}");
            }
        }
    }
}