use crate::index::GraphIndex;
use crate::interface::{GraphBase, StaticGraph};
use std::collections::BinaryHeap;
use std::marker::PhantomData;
/// Dijkstra search state that keeps tentative `usize` distances in an [`EpochNodeWeightArray`],
/// so resetting between searches usually avoids rewriting every node's entry.
pub type DefaultDijkstra<Graph> = Dijkstra<Graph, EpochNodeWeightArray<usize>>;
/// A weight type that has a distinguished "infinite" value, used for nodes not reached yet.
pub trait Weight {
    /// Returns the value that stands for an infinite (unreached) weight.
    fn infinity() -> Self;
}

/// For `usize`, the largest representable value plays the role of infinity.
impl Weight for usize {
    #[inline]
    fn infinity() -> Self {
        usize::MAX
    }
}
/// Edge payloads that carry a `usize` length usable as a shortest-path edge weight.
pub trait WeightedEdgeData {
    /// Returns the weight (length) of this edge.
    fn weight(&self) -> usize;
}

/// A bare `usize` edge payload is its own weight.
impl WeightedEdgeData for usize {
    #[inline]
    fn weight(&self) -> usize {
        let &value = self;
        value
    }
}
/// Per-node weight storage addressed by `usize` node positions.
pub trait NodeWeightArray<WeightType> {
    /// Creates storage for `size` nodes.
    fn new(size: usize) -> Self;
    /// Returns the weight currently stored for `node_index`.
    fn get(&self, node_index: usize) -> WeightType;
    /// Returns a mutable reference to the weight of `node_index`.
    fn get_mut(&mut self, node_index: usize) -> &mut WeightType;
    /// Stores `weight` for `node_index`.
    fn set(&mut self, node_index: usize, weight: WeightType);
    /// Resets every entry (the implementations in this module make all entries read as
    /// `Weight::infinity()` afterwards).
    fn clear(&mut self);
}
/// Dense storage: every entry is held explicitly, so `clear` rewrites the whole vector.
impl<WeightType: Weight + Copy> NodeWeightArray<WeightType> for Vec<WeightType> {
    /// Allocates `size` entries, all set to `Weight::infinity()`.
    fn new(size: usize) -> Self {
        let infinity = WeightType::infinity();
        vec![infinity; size]
    }

    /// Returns the stored weight; panics if `node_index` is out of range.
    #[inline]
    fn get(&self, node_index: usize) -> WeightType {
        let slots: &[WeightType] = self;
        slots[node_index]
    }

    /// Returns a mutable reference to the stored weight; panics if out of range.
    #[inline]
    fn get_mut(&mut self, node_index: usize) -> &mut WeightType {
        let slots: &mut [WeightType] = self;
        &mut slots[node_index]
    }

    /// Overwrites the stored weight; panics if out of range.
    #[inline]
    fn set(&mut self, node_index: usize, weight: WeightType) {
        let slot = &mut self[node_index];
        *slot = weight;
    }

    /// Resets every entry to `Weight::infinity()` (length is unchanged).
    fn clear(&mut self) {
        self.iter_mut()
            .for_each(|slot| *slot = WeightType::infinity());
    }
}
/// Tracks, per index, whether the entry was marked during the current epoch.
///
/// `clear` starts a new epoch by incrementing a counter, which unmarks every index in O(1).
/// Only when the `u32` counter cannot be incremented any further are all stored epochs
/// reset explicitly (O(len)).
pub struct EpochArray {
    /// Epoch in which each index was last marked; `0` is never a current epoch.
    epochs: Vec<u32>,
    /// The epoch that counts as "marked"; always at least 1.
    current_epoch: u32,
}

impl EpochArray {
    /// Creates an array with `len` indices, none of which is marked.
    pub fn new(len: usize) -> Self {
        Self {
            epochs: vec![0; len],
            current_epoch: 1,
        }
    }

    /// Unmarks every index.
    pub fn clear(&mut self) {
        if self.current_epoch == u32::MAX {
            // The counter cannot advance further. Reset explicitly so that entries stamped in
            // old epochs can never collide with a reused epoch value.
            for epoch in self.epochs.iter_mut() {
                *epoch = 0;
            }
            self.current_epoch = 1;
        } else {
            self.current_epoch += 1;
        }
    }

    /// Marks `index` as belonging to the current epoch.
    ///
    /// # Panics
    /// Panics if `index` is out of range.
    #[inline]
    pub fn update(&mut self, index: usize) {
        // Bounds-checked on purpose: this is a safe public method, so an unchecked write would
        // be undefined behavior for an out-of-range `index`.
        self.epochs[index] = self.current_epoch;
    }

    /// Returns whether `index` was marked in the current epoch.
    ///
    /// # Panics
    /// Panics if `index` is out of range.
    #[inline]
    pub fn get(&self, index: usize) -> bool {
        self.epochs[index] == self.current_epoch
    }

    /// Marks `index` and returns whether it was already marked in the current epoch.
    ///
    /// # Panics
    /// Panics if `index` is out of range.
    #[inline]
    pub fn get_and_update(&mut self, index: usize) -> bool {
        let current_epoch = self.current_epoch;
        // One bounds check for both the read and the write.
        let epoch = &mut self.epochs[index];
        if *epoch == current_epoch {
            true
        } else {
            *epoch = current_epoch;
            false
        }
    }
}
/// Node-weight storage whose `clear` is usually O(1).
///
/// Weights written in an earlier epoch (tracked by an [`EpochArray`]) read as
/// `Weight::infinity()`. They are only overwritten when the slot is next accessed mutably.
pub struct EpochNodeWeightArray<WeightType> {
    /// Raw weights; an entry is meaningful only if `epochs` marks it for the current epoch.
    weights: Vec<WeightType>,
    /// Records which entries of `weights` are valid in the current epoch.
    epochs: EpochArray,
}

impl<WeightType: Weight> EpochNodeWeightArray<WeightType> {
    /// Brings the slot at `node_index` into the current epoch.
    ///
    /// If the stored value came from an earlier epoch, it is reset to `Weight::infinity()`.
    #[inline]
    fn make_current(&mut self, node_index: usize) {
        let already_current = self.epochs.get_and_update(node_index);
        if already_current {
            return;
        }
        self.weights[node_index] = WeightType::infinity();
    }
}

impl<WeightType: Weight + Copy> NodeWeightArray<WeightType> for EpochNodeWeightArray<WeightType> {
    /// Creates storage for `len` nodes, all reading as `Weight::infinity()`.
    fn new(len: usize) -> Self {
        let weights = vec![WeightType::infinity(); len];
        let epochs = EpochArray::new(len);
        Self { weights, epochs }
    }

    /// Returns the node's weight, or `Weight::infinity()` if it was not written this epoch.
    #[inline]
    fn get(&self, node_index: usize) -> WeightType {
        if !self.epochs.get(node_index) {
            return WeightType::infinity();
        }
        self.weights[node_index]
    }

    /// Returns a mutable reference to the node's weight, resetting a stale entry first.
    #[inline]
    fn get_mut(&mut self, node_index: usize) -> &mut WeightType {
        self.make_current(node_index);
        &mut self.weights[node_index]
    }

    /// Stores `weight` for the node and marks it as written in the current epoch.
    #[inline]
    fn set(&mut self, node_index: usize, weight: WeightType) {
        // The weight is written first, so an out-of-range `node_index` is caught by the
        // bounds-checked `weights` index before `epochs` is touched.
        let slot = &mut self.weights[node_index];
        *slot = weight;
        self.epochs.update(node_index);
    }

    /// Makes every entry read as `Weight::infinity()` by starting a new epoch.
    fn clear(&mut self) {
        self.epochs.clear();
    }
}
/// Decides which nodes a Dijkstra search reports distances for.
pub trait DijkstraTargetMap<Graph: GraphBase> {
    /// Returns `true` if `node_index` is a target of the search.
    fn is_target(&self, node_index: Graph::NodeIndex) -> bool;
}

/// A `Vec<bool>` indexed by node position (`as_usize()`), where `true` marks a target.
///
/// Panics if the vector is shorter than the queried node position.
impl<Graph: GraphBase> DijkstraTargetMap<Graph> for Vec<bool> {
    fn is_target(&self, node_index: Graph::NodeIndex) -> bool {
        let position = node_index.as_usize();
        self[position]
    }
}
/// Reusable state for single-source Dijkstra searches on a `StaticGraph`.
///
/// The priority queue and the per-node tentative distances are kept between calls, so
/// repeated searches do not reallocate. Both are reset at the end of every search.
pub struct Dijkstra<Graph: GraphBase, NodeWeights> {
    /// Min-heap (via `Reverse`) of `(tentative distance, node)` entries.
    /// It may hold stale entries for nodes that were later reached more cheaply.
    queue: BinaryHeap<std::cmp::Reverse<(usize, Graph::NodeIndex)>>,
    /// Tentative distance per node, addressed by `NodeIndex::as_usize()`.
    node_weights: NodeWeights,
    graph: PhantomData<Graph>,
}

impl<
        EdgeData: WeightedEdgeData,
        Graph: StaticGraph<EdgeData = EdgeData>,
        NodeWeights: NodeWeightArray<usize>,
    > Dijkstra<Graph, NodeWeights>
{
    /// Creates search state sized for `graph` (only `graph.node_count()` is read).
    pub fn new(graph: &Graph) -> Self {
        Self {
            queue: BinaryHeap::new(),
            node_weights: NodeWeights::new(graph.node_count()),
            graph: Default::default(),
        }
    }

    /// Computes shortest-path lengths from `source` to target nodes.
    ///
    /// Results are written in non-decreasing order of distance.
    ///
    /// * `graph` - the graph to search; presumably the one passed to [`Self::new`], since node
    ///   positions index storage sized by that graph's node count.
    /// * `source` - the start node.
    /// * `targets` - decides which reached nodes are reported.
    /// * `target_amount` - the search stops as soon as this many targets have been reported.
    ///   With `0` this never triggers, so all targets within `max_weight` are reported.
    /// * `max_weight` - targets farther than this are not reported.
    /// * `forbid_source_target` - if `true`, `source` itself is never reported.
    /// * `distances` - cleared, then filled with `(target, distance)` pairs.
    ///
    /// Path lengths are summed with saturating arithmetic. A length that would exceed
    /// `usize::MAX` is therefore treated as unreachable instead of wrapping around.
    #[inline(never)]
    // The parameters are independent search options; bundling them would change the public API.
    #[allow(clippy::too_many_arguments)]
    pub fn shortest_path_lens<TargetMap: DijkstraTargetMap<Graph>>(
        &mut self,
        graph: &Graph,
        source: Graph::NodeIndex,
        targets: &TargetMap,
        target_amount: usize,
        max_weight: usize,
        forbid_source_target: bool,
        distances: &mut Vec<(Graph::NodeIndex, usize)>,
    ) {
        self.queue.push(std::cmp::Reverse((0, source)));
        self.node_weights.set(source.as_usize(), 0);
        distances.clear();

        while let Some(std::cmp::Reverse((weight, node_index))) = self.queue.pop() {
            // Skip stale queue entries: the node was already reached with a smaller weight.
            let actual_weight = self.node_weights.get(node_index.as_usize());
            if actual_weight < weight {
                continue;
            }
            debug_assert_eq!(actual_weight, weight);
            // Entries pop in non-decreasing weight order, so nothing within the bound remains.
            if weight > max_weight {
                break;
            }

            if targets.is_target(node_index) && (!forbid_source_target || node_index != source) {
                distances.push((node_index, weight));
                if distances.len() == target_amount {
                    break;
                }
            }

            for out_neighbor in graph.out_neighbors(node_index) {
                // Saturate instead of overflowing: `+` panics in debug builds and wraps to a
                // bogus small distance in release builds.
                let new_neighbor_weight =
                    weight.saturating_add(graph.edge_data(out_neighbor.edge_id).weight());
                // A node beyond `max_weight` can never be reported; don't grow the queue with it.
                if new_neighbor_weight > max_weight {
                    continue;
                }
                let neighbor_weight = self.node_weights.get_mut(out_neighbor.node_id.as_usize());
                if new_neighbor_weight < *neighbor_weight {
                    *neighbor_weight = new_neighbor_weight;
                    self.queue.push(std::cmp::Reverse((
                        new_neighbor_weight,
                        out_neighbor.node_id,
                    )));
                }
            }
        }

        // Leave the reusable state empty for the next search.
        self.queue.clear();
        self.node_weights.clear();
    }
}
#[cfg(test)]
mod tests {
    use crate::algo::dijkstra::DefaultDijkstra;
    use crate::implementation::petgraph_impl;
    use crate::interface::MutableGraphContainer;

    // These use `assert_eq!` / `assert!` rather than `debug_assert_eq!`: the debug variant is
    // compiled out under `cargo test --release`, which would make the tests pass vacuously.

    /// Graph: n1 -2-> n2 -2-> n3 plus a direct edge n1 -5-> n3.
    ///
    /// One `Dijkstra` instance is reused across queries to check that its state is reset.
    #[test]
    fn test_dijkstra_simple() {
        let mut graph = petgraph_impl::new();
        let n1 = graph.add_node(());
        let n2 = graph.add_node(());
        let n3 = graph.add_node(());
        graph.add_edge(n1, n2, 2);
        graph.add_edge(n2, n3, 2);
        graph.add_edge(n1, n3, 5);
        let mut dijkstra = DefaultDijkstra::new(&graph);
        let mut distances = Vec::new();
        let mut targets = vec![false, false, true];
        dijkstra.shortest_path_lens(&graph, n1, &targets, 1, 6, false, &mut distances);
        assert_eq!(distances, vec![(n3, 4)]);
        dijkstra.shortest_path_lens(&graph, n1, &targets, 1, 6, false, &mut distances);
        assert_eq!(distances, vec![(n3, 4)]);
        dijkstra.shortest_path_lens(&graph, n2, &targets, 1, 6, false, &mut distances);
        assert_eq!(distances, vec![(n3, 2)]);
        dijkstra.shortest_path_lens(&graph, n3, &targets, 1, 6, false, &mut distances);
        assert_eq!(distances, vec![(n3, 0)]);
        targets = vec![false, true, false];
        dijkstra.shortest_path_lens(&graph, n3, &targets, 1, 6, false, &mut distances);
        assert!(distances.is_empty());
    }

    /// Graph: the cycle n1 -2-> n2 -2-> n3 -5-> n1.
    #[test]
    fn test_dijkstra_cycle() {
        let mut graph = petgraph_impl::new();
        let n1 = graph.add_node(());
        let n2 = graph.add_node(());
        let n3 = graph.add_node(());
        graph.add_edge(n1, n2, 2);
        graph.add_edge(n2, n3, 2);
        graph.add_edge(n3, n1, 5);
        let mut dijkstra = DefaultDijkstra::new(&graph);
        let mut distances = Vec::new();
        let targets = vec![false, false, true];
        dijkstra.shortest_path_lens(&graph, n1, &targets, 1, 6, false, &mut distances);
        assert_eq!(distances, vec![(n3, 4)]);
    }

    /// `target_amount`, `max_weight` and `forbid_source_target` limit what is reported.
    #[test]
    fn test_dijkstra_limits() {
        let mut graph = petgraph_impl::new();
        let n1 = graph.add_node(());
        let n2 = graph.add_node(());
        let n3 = graph.add_node(());
        graph.add_edge(n1, n2, 2);
        graph.add_edge(n2, n3, 2);
        graph.add_edge(n3, n1, 5);
        let mut dijkstra = DefaultDijkstra::new(&graph);
        let mut distances = Vec::new();

        // Several targets are reported in order of distance.
        let targets = vec![false, true, true];
        dijkstra.shortest_path_lens(&graph, n1, &targets, 2, 10, false, &mut distances);
        assert_eq!(distances, vec![(n2, 2), (n3, 4)]);

        // The search stops after `target_amount` targets.
        dijkstra.shortest_path_lens(&graph, n1, &targets, 1, 10, false, &mut distances);
        assert_eq!(distances, vec![(n2, 2)]);

        // n3 is at distance 4, beyond `max_weight` = 3.
        dijkstra.shortest_path_lens(&graph, n1, &targets, 2, 3, false, &mut distances);
        assert_eq!(distances, vec![(n2, 2)]);

        // The source is a target at distance 0 unless that is forbidden.
        let targets = vec![true, false, false];
        dijkstra.shortest_path_lens(&graph, n1, &targets, 1, 10, false, &mut distances);
        assert_eq!(distances, vec![(n1, 0)]);
        dijkstra.shortest_path_lens(&graph, n1, &targets, 1, 10, true, &mut distances);
        assert!(distances.is_empty());
    }
}