/// One entry stored inside a bucket of [`RadixHeap`].
#[derive(Debug, Clone)]
struct Node {
    /// Priority of the entry; smaller keys are extracted first.
    key: u32,
    /// Caller-supplied identifier, also the index into the heap's position table.
    node_id: usize,
}
/// Lightweight, copyable reference to an entry previously inserted into a
/// [`RadixHeap`]; pass it to `decrease_key` to lower that entry's priority.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RadixHandle {
    /// The node id the handle was created for.
    node_id: usize,
}
/// Monotone priority queue over `u32` keys with `decrease_key` support.
///
/// Keys passed to `insert`/`decrease_key` must not be smaller than the most
/// recently extracted key while the heap is non-empty.
#[derive(Debug, Clone)]
pub struct RadixHeap {
    /// Bucket 0 holds entries whose key equals `min_key`; bucket `i` (1..=32)
    /// holds entries whose highest bit differing from `min_key` is bit `i - 1`.
    buckets: Vec<Vec<Node>>,
    /// Lower bound on every stored key (the last extracted key, or 0 initially).
    min_key: u32,
    /// Number of entries currently stored.
    size: usize,
    /// `node_positions[id]` is `Some((bucket, index))` while `id` is stored, else `None`.
    node_positions: Vec<Option<(usize, usize)>>,
}
impl RadixHeap {
pub fn new() -> Self {
let max_buckets = 33;
RadixHeap {
buckets: vec![Vec::new(); max_buckets],
min_key: 0,
size: 0,
node_positions: Vec::new(),
}
}
pub fn with_capacity(max_node_id: usize) -> Self {
let max_buckets = 33;
RadixHeap {
buckets: vec![Vec::new(); max_buckets],
min_key: 0,
size: 0,
node_positions: vec![None; max_node_id + 1],
}
}
fn bucket_index(&self, key: u32) -> usize {
debug_assert!(key >= self.min_key);
let delta = key ^ self.min_key;
(32 - delta.leading_zeros()) as usize
}
pub fn insert(&mut self, key: u32, node_id: usize) -> RadixHandle {
let bucket_idx = self.bucket_index(key);
let pos = self.buckets[bucket_idx].len();
self.buckets[bucket_idx].push(Node { key, node_id });
if node_id >= self.node_positions.len() {
self.node_positions.resize(node_id + 1, None);
}
self.node_positions[node_id] = Some((bucket_idx, pos));
self.size += 1;
RadixHandle { node_id }
}
pub fn extract_min(&mut self) -> Option<(u32, usize)> {
if self.size == 0 {
return None;
}
let first_bucket = self
.buckets
.iter()
.enumerate()
.find(|(_, bucket)| !bucket.is_empty())
.map(|(i, _)| i);
let bucket_idx = first_bucket?;
let (min_pos, min_key) = self.buckets[bucket_idx]
.iter()
.enumerate()
.min_by(|x, y| x.1.key.cmp(&y.1.key))
.map(|(pos, node)| (pos, node.key))
.unwrap_or((0, u32::MAX));
let extracted_node = self.buckets[bucket_idx].swap_remove(min_pos);
self.node_positions[extracted_node.node_id] = None;
let bucket_len = self.buckets[bucket_idx].len();
if min_pos < bucket_len {
let swapped_node = &self.buckets[bucket_idx][min_pos];
self.node_positions[swapped_node.node_id] = Some((bucket_idx, min_pos));
}
self.size -= 1;
self.min_key = min_key;
let nodes_to_redistribute = std::mem::take(&mut self.buckets[bucket_idx]);
for node in nodes_to_redistribute {
let node_id = node.node_id;
let new_bucket_idx = self.bucket_index(node.key);
let pos = self.buckets[new_bucket_idx].len();
self.buckets[new_bucket_idx].push(node);
self.node_positions[node_id] = Some((new_bucket_idx, pos));
}
Some((min_key, extracted_node.node_id))
}
pub fn decrease_key(&mut self, handle: &RadixHandle, new_key: u32) {
let node_id = handle.node_id;
if node_id >= self.node_positions.len() {
self.node_positions.resize(node_id + 1, None);
}
let (old_bucket_idx, old_pos) = match self.node_positions[node_id] {
Some(pos) => pos,
None => return, };
let current_key = self.buckets[old_bucket_idx][old_pos].key;
if new_key >= current_key {
return; }
let mut node = self.buckets[old_bucket_idx].swap_remove(old_pos);
let bucket_len = self.buckets[old_bucket_idx].len();
if old_pos < bucket_len {
let swapped_node = &self.buckets[old_bucket_idx][old_pos];
self.node_positions[swapped_node.node_id] = Some((old_bucket_idx, old_pos));
}
node.key = new_key;
let new_bucket_idx = self.bucket_index(new_key);
let new_pos = self.buckets[new_bucket_idx].len();
self.buckets[new_bucket_idx].push(node);
self.node_positions[node_id] = Some((new_bucket_idx, new_pos));
}
pub fn is_empty(&self) -> bool {
self.size == 0
}
pub fn len(&self) -> usize {
self.size
}
}
impl Default for RadixHeap {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_radix_heap_basic() {
        let mut heap = RadixHeap::new();
        assert_eq!(heap.extract_min(), None);
        heap.insert(10, 1);
        heap.insert(5, 2);
        heap.insert(15, 3);
        assert_eq!(heap.extract_min(), Some((5, 2)));
        assert_eq!(heap.extract_min(), Some((10, 1)));
        assert_eq!(heap.extract_min(), Some((15, 3)));
        assert_eq!(heap.extract_min(), None);
    }

    #[test]
    fn test_radix_heap_non_decreasing() {
        let mut heap = RadixHeap::new();
        heap.insert(0, 0);
        assert_eq!(heap.extract_min(), Some((0, 0)));
        heap.insert(5, 1);
        heap.insert(10, 2);
        assert_eq!(heap.extract_min(), Some((5, 1)));
        heap.insert(15, 3);
        assert_eq!(heap.extract_min(), Some((10, 2)));
        assert_eq!(heap.extract_min(), Some((15, 3)));
    }

    #[test]
    fn test_radix_heap_large_range() {
        let mut heap = RadixHeap::new();
        heap.insert(1000, 1);
        heap.insert(1, 2);
        heap.insert(500, 3);
        assert_eq!(heap.extract_min(), Some((1, 2)));
        assert_eq!(heap.extract_min(), Some((500, 3)));
        assert_eq!(heap.extract_min(), Some((1000, 1)));
    }

    #[test]
    fn test_radix_heap_dijkstra_like_sequence() {
        let mut heap = RadixHeap::new();
        let handles = [
            heap.insert(0, 0),
            heap.insert(u32::MAX, 1),
            heap.insert(u32::MAX, 2),
            heap.insert(u32::MAX, 3),
            heap.insert(u32::MAX, 4),
        ];
        assert_eq!(heap.extract_min(), Some((0, 0)));
        heap.decrease_key(&handles[1], 10);
        heap.decrease_key(&handles[2], 20);
        heap.decrease_key(&handles[3], 30);
        heap.decrease_key(&handles[4], 40);
        assert_eq!(heap.extract_min(), Some((10, 1)));
        heap.decrease_key(&handles[2], 15);
        heap.decrease_key(&handles[3], 25);
        assert_eq!(heap.extract_min(), Some((15, 2)));
        assert_eq!(heap.extract_min(), Some((25, 3)));
        assert_eq!(heap.extract_min(), Some((40, 4)));
    }

    #[test]
    fn test_radix_heap_invariant() {
        let mut heap = RadixHeap::new();
        heap.insert(7, 1);
        heap.insert(8, 2);
        assert_eq!(heap.extract_min(), Some((7, 1)));
        heap.insert(9, 3);
        assert_eq!(heap.extract_min(), Some((8, 2)));
        assert_eq!(heap.extract_min(), Some((9, 3)));
    }

    /// `len`/`is_empty` track inserts and extractions.
    #[test]
    fn test_radix_heap_len_and_is_empty() {
        let mut heap = RadixHeap::new();
        assert!(heap.is_empty());
        assert_eq!(heap.len(), 0);
        heap.insert(3, 0);
        heap.insert(4, 1);
        assert!(!heap.is_empty());
        assert_eq!(heap.len(), 2);
        assert_eq!(heap.extract_min(), Some((3, 0)));
        assert_eq!(heap.len(), 1);
        assert_eq!(heap.extract_min(), Some((4, 1)));
        assert!(heap.is_empty());
    }

    /// Tied keys are all returned; their relative order is not part of the contract.
    #[test]
    fn test_radix_heap_equal_keys() {
        let mut heap = RadixHeap::new();
        for id in 0..10 {
            heap.insert(7, id);
        }
        let mut ids = Vec::new();
        while let Some((key, id)) = heap.extract_min() {
            assert_eq!(key, 7);
            ids.push(id);
        }
        ids.sort_unstable();
        assert_eq!(ids, (0..10).collect::<Vec<usize>>());
    }

    /// `decrease_key` ignores non-decreasing keys, unknown handles and extracted nodes.
    #[test]
    fn test_radix_heap_decrease_key_noops() {
        let mut heap = RadixHeap::new();
        let handle = heap.insert(10, 0);
        heap.decrease_key(&handle, 10);
        heap.decrease_key(&handle, 50);
        heap.decrease_key(&RadixHandle { node_id: 99 }, 1);
        assert_eq!(heap.len(), 1);
        assert_eq!(heap.extract_min(), Some((10, 0)));
        heap.decrease_key(&handle, 1);
        assert_eq!(heap.extract_min(), None);
    }

    /// Node ids beyond the pre-sized position table are still accepted.
    #[test]
    fn test_radix_heap_with_capacity_grows() {
        let mut heap = RadixHeap::with_capacity(2);
        heap.insert(8, 1);
        heap.insert(6, 5);
        assert_eq!(heap.extract_min(), Some((6, 5)));
        assert_eq!(heap.extract_min(), Some((8, 1)));
        assert_eq!(heap.extract_min(), None);
    }
}