/// Handle returned by [`DialHeap::insert`] that identifies a node in later
/// [`DialHeap::decrease_key`] calls.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct DialHandle {
    node_id: usize,
}

impl DialHandle {
    /// Returns the id of the node this handle refers to.
    #[must_use]
    pub fn node_id(self) -> usize {
        self.node_id
    }
}
/// Bucket-based priority queue for Dial's algorithm: keys are small
/// non-negative integer distances, and bucket `d` holds every queued node
/// whose current distance is `d`.
#[derive(Debug, Clone)]
pub struct DialHeap {
    /// `buckets[d]` lists the ids of queued nodes whose distance is `d`.
    buckets: Vec<Vec<usize>>,
    /// Bucket index at which `extract_min` resumes its upward scan.
    current_bucket: usize,
    /// Highest bucket index that has received a node; bounds the scan.
    max_distance: usize,
    /// Most recently recorded distance per node id; `u32::MAX` when no finite
    /// distance has been recorded.
    distances: Vec<u32>,
    /// `(bucket, index within bucket)` of each queued node, `None` if not queued.
    node_positions: Vec<Option<(usize, usize)>>,
    /// Number of nodes currently queued.
    size: usize,
}
/// Exclusive upper bound on bucket indices, i.e. on any finite distance a
/// [`DialHeap`] accepts; a distance at or above it makes the heap panic.
const MAX_BUCKETS: usize = 10 * 1_000 * 1_000;
impl DialHeap {
    /// Creates an empty heap.
    ///
    /// * `max_nodes` – initial size of the per-node tables; node ids at or beyond
    ///   it are still accepted and grow the tables on demand.
    /// * `initial_max_distance` – highest distance to pre-allocate a bucket for.
    ///   Clamped to `MAX_BUCKETS - 1`, the largest distance the heap can hold.
    #[must_use]
    pub fn new(max_nodes: usize, initial_max_distance: usize) -> Self {
        // Buckets at or beyond MAX_BUCKETS are rejected by `ensure_bucket_capacity`,
        // so allocating them only wastes memory; clamping also keeps `+ 1` from
        // overflowing when `initial_max_distance == usize::MAX`.
        let bucket_count = initial_max_distance.min(MAX_BUCKETS - 1) + 1;
        DialHeap {
            buckets: vec![Vec::new(); bucket_count],
            current_bucket: 0,
            max_distance: 0,
            distances: vec![u32::MAX; max_nodes],
            node_positions: vec![None; max_nodes],
            size: 0,
        }
    }

    /// Inserts `node_id` with key `distance`, or lowers its key if it is already
    /// queued with a larger one.
    ///
    /// A key that is not strictly smaller than the node's recorded distance is
    /// ignored. This includes nodes that were already extracted, whose final
    /// distance stays recorded. `u32::MAX` is never enqueued. For a node that is
    /// not currently queued, it resets the recorded distance to `u32::MAX`. For a
    /// queued node, it is a no-op.
    ///
    /// # Panics
    ///
    /// Panics if `distance` is finite and `>= MAX_BUCKETS`.
    pub fn insert(&mut self, distance: u32, node_id: usize) -> DialHandle {
        self.ensure_node_capacity(node_id);
        if distance == u32::MAX {
            // Overwriting the key of a queued node would leave it in a finite
            // bucket while `extract_min` later reports `u32::MAX` for it.
            if self.node_positions[node_id].is_none() {
                self.distances[node_id] = distance;
            }
        } else {
            self.lower_key(node_id, distance);
        }
        DialHandle { node_id }
    }

    /// Removes and returns the queued node with the smallest distance as
    /// `(distance, node_id)`, or `None` if nothing is queued.
    ///
    /// Nodes with equal distances come out in no particular order. The node's
    /// distance stays recorded after extraction.
    pub fn extract_min(&mut self) -> Option<(u32, usize)> {
        if self.size == 0 {
            return None;
        }
        while self.current_bucket <= self.max_distance {
            if let Some(node_id) = self.buckets[self.current_bucket].pop() {
                self.node_positions[node_id] = None;
                self.size -= 1;
                return Some((self.distances[node_id], node_id));
            }
            self.current_bucket += 1;
        }
        None
    }

    /// Lowers the key of the node behind `handle` to `new_distance`.
    ///
    /// Does nothing if `new_distance` is `u32::MAX` or is not strictly smaller
    /// than the node's recorded distance. A previously extracted node is
    /// re-enqueued if `new_distance` is below its recorded distance.
    ///
    /// # Panics
    ///
    /// Panics if `new_distance` is finite and `>= MAX_BUCKETS`.
    pub fn decrease_key(&mut self, handle: &DialHandle, new_distance: u32) {
        if new_distance == u32::MAX {
            return;
        }
        self.ensure_node_capacity(handle.node_id);
        self.lower_key(handle.node_id, new_distance);
    }

    /// Grows the per-node tables so that `node_id` is a valid index.
    fn ensure_node_capacity(&mut self, node_id: usize) {
        if node_id >= self.distances.len() {
            self.distances.resize(node_id + 1, u32::MAX);
            self.node_positions.resize(node_id + 1, None);
        }
    }

    /// Moves `node_id` into the bucket for `distance`, or enqueues it if it is
    /// not queued. Only acts if `distance` is strictly smaller than the recorded
    /// distance.
    ///
    /// Callers guarantee `distance != u32::MAX` and that `node_id` is in range.
    fn lower_key(&mut self, node_id: usize, distance: u32) {
        if distance >= self.distances[node_id] {
            return;
        }
        let bucket_idx = distance as usize;
        // Check capacity before touching the old bucket, so that a panic here
        // leaves the heap unmodified.
        self.ensure_bucket_capacity(bucket_idx);
        match self.node_positions[node_id] {
            Some((old_bucket, old_pos)) => self.remove_from_bucket(old_bucket, old_pos),
            None => self.size += 1,
        }
        let pos = self.buckets[bucket_idx].len();
        self.buckets[bucket_idx].push(node_id);
        self.node_positions[node_id] = Some((bucket_idx, pos));
        self.distances[node_id] = distance;
        self.max_distance = self.max_distance.max(bucket_idx);
        // The scan pointer only moves forward in `extract_min`. Pull it back when
        // the new key lies below it, otherwise this node would never be visited.
        self.current_bucket = self.current_bucket.min(bucket_idx);
    }

    /// Removes the entry at `pos` of bucket `bucket_idx` in O(1) and repairs the
    /// recorded position of the node that `swap_remove` moved into its slot.
    ///
    /// Does not update `size` or the removed node's own position. The caller
    /// re-inserts the node and overwrites its position.
    fn remove_from_bucket(&mut self, bucket_idx: usize, pos: usize) {
        let bucket = &mut self.buckets[bucket_idx];
        bucket.swap_remove(pos);
        if let Some(&moved) = bucket.get(pos) {
            self.node_positions[moved] = Some((bucket_idx, pos));
        }
    }

    /// Makes sure bucket `dist` exists, growing the bucket list if needed.
    ///
    /// # Panics
    ///
    /// Panics if `dist >= MAX_BUCKETS`.
    fn ensure_bucket_capacity(&mut self, dist: usize) {
        if dist >= MAX_BUCKETS {
            panic!(
                "Distance {} exceeds maximum supported distance {} for Dial's algorithm. \
                 Consider using a different heap implementation for graphs with large edge weights.",
                dist, MAX_BUCKETS
            );
        }
        if dist >= self.buckets.len() {
            self.buckets.resize(dist + 1, Vec::new());
        }
    }

    /// Returns `true` if no node is currently queued.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.size == 0
    }

    /// Returns the number of nodes currently queued.
    #[must_use]
    pub fn len(&self) -> usize {
        self.size
    }
}
impl Default for DialHeap {
fn default() -> Self {
Self::new(0, 10000)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pops entries in order until the heap reports `None`.
    fn drain(heap: &mut DialHeap) -> Vec<(u32, usize)> {
        std::iter::from_fn(|| heap.extract_min()).collect()
    }

    #[test]
    fn test_dial_heap_basic() {
        let mut heap = DialHeap::new(10, 100);
        assert_eq!(heap.extract_min(), None);
        for (dist, node) in [(10u32, 1usize), (5, 2), (15, 3)] {
            heap.insert(dist, node);
        }
        assert_eq!(drain(&mut heap), vec![(5, 2), (10, 1), (15, 3)]);
        assert_eq!(heap.extract_min(), None);
    }

    #[test]
    fn test_dial_heap_decrease_key() {
        let mut heap = DialHeap::new(10, 100);
        let first = heap.insert(20, 1);
        let second = heap.insert(30, 2);
        for (handle, lowered) in [(&first, 10u32), (&second, 15)] {
            heap.decrease_key(handle, lowered);
        }
        for expected in [(10u32, 1usize), (15, 2)] {
            assert_eq!(heap.extract_min(), Some(expected));
        }
    }

    #[test]
    fn test_dial_heap_dijkstra_like() {
        let mut heap = DialHeap::new(10, 1000);
        heap.insert(0, 0);
        let unreached: Vec<DialHandle> = (1usize..=3)
            .map(|node| heap.insert(u32::MAX, node))
            .collect();
        assert_eq!(heap.extract_min(), Some((0, 0)));
        let targets = [10u32, 20, 30];
        for (handle, dist) in unreached.iter().zip(targets) {
            heap.decrease_key(handle, dist);
        }
        for (node, dist) in (1usize..=3).zip(targets) {
            assert_eq!(heap.extract_min(), Some((dist, node)));
        }
    }

    #[test]
    #[should_panic(expected = "exceeds maximum supported distance")]
    fn test_dial_heap_max_buckets_insert() {
        let too_far = MAX_BUCKETS as u32;
        DialHeap::new(10, 100).insert(too_far, 1);
    }

    #[test]
    #[should_panic(expected = "exceeds maximum supported distance")]
    fn test_dial_heap_max_buckets_decrease_key() {
        let mut heap = DialHeap::new(10, 100);
        let unreached = heap.insert(u32::MAX, 1);
        let too_far = MAX_BUCKETS as u32;
        heap.decrease_key(&unreached, too_far);
    }
}