use crate::algorithms::shortest_path::ShortestPathInfo;
use crate::{Error, ErrorKind, Graph};
use rayon::prelude::{IntoParallelIterator, ParallelIterator};
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap};
use std::fmt::Display;
use std::hash::Hash;
use std::mem;
/// Node-count threshold for switching to rayon: graphs with more nodes than this
/// are processed in parallel (provided rayon has more than one thread available);
/// smaller graphs are processed serially on the current thread.
const SERIAL_TO_PARALLEL_THRESHOLD: usize = 20;
/// One entry of the Dijkstra fringe (the priority queue of tentative distances).
///
/// `distance` is stored *negated* so that `BinaryHeap`, a max-heap, hands back the
/// closest node first. `count` is a push counter that breaks ties between equal
/// distances, and `node_index` is the final tie-breaker.
#[derive(PartialEq)]
struct FringeNode {
    pub node_index: usize,
    pub count: i32,
    pub distance: f64,
}

impl Ord for FringeNode {
    /// Lexicographic order on `(distance, count, node_index)`.
    ///
    /// Incomparable distances (a NaN is involved) are treated as equal and fall
    /// through to the tie-breakers.
    fn cmp(&self, other: &Self) -> Ordering {
        self.distance
            .partial_cmp(&other.distance)
            .unwrap_or(Ordering::Equal)
            .then_with(|| self.count.cmp(&other.count))
            .then_with(|| self.node_index.cmp(&other.node_index))
    }
}

impl PartialOrd for FringeNode {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

// `f64` is not `Eq`, so this marker impl cannot be derived; fringe distances are
// expected to be ordinary (non-NaN) numbers.
impl Eq for FringeNode {}
/// Message attached to `ErrorKind::ContradictoryPaths` errors: a search reached an
/// already-finalised node through a strictly shorter path, which typically means
/// that some edge weights are negative.
static CONTRADICTORY_PATHS_ERROR_MESSAGE: &str =
    "Contradictory paths found, do some edges have negative weights?";
/// Computes shortest-path information from every node of `graph` to the nodes it
/// can reach.
///
/// The outer map is keyed by source node name. Each inner map is keyed by the
/// name of a reached node.
///
/// * `weighted` - use edge weights; otherwise every edge costs `1.0`.
/// * `target` - if given, each per-source search stops once the target's distance
///   is finalised.
/// * `cutoff` - if given, tentative distances above this value are ignored.
/// * `first_only` - keep only the first shortest path found to each node.
/// * `with_paths` - also record the node paths, not just the distances.
///
/// Sources are processed in parallel with rayon for graphs larger than
/// `SERIAL_TO_PARALLEL_THRESHOLD` nodes when more than one rayon thread exists.
///
/// # Errors
/// * Any error from `graph.ensure_weighted()` when `weighted` is `true`.
/// * The node-lookup error from `graph.get_node_index` when `target` is not a node
///   of `graph`.
///
/// NOTE(review): a contradictory-paths error from an individual search is
/// `unwrap`ped inside the per-source helpers, so it currently panics instead of
/// being returned.
pub fn all_pairs<T, A>(
    graph: &Graph<T, A>,
    weighted: bool,
    target: Option<T>,
    cutoff: Option<f64>,
    first_only: bool,
    with_paths: bool,
) -> Result<HashMap<T, HashMap<T, ShortestPathInfo<T>>>, Error>
where
    T: Hash + Eq + Clone + Ord + Display + Send + Sync,
    A: Clone + Send + Sync,
{
    if weighted {
        graph.ensure_weighted()?;
    }
    // The per-source helpers resolve the target with `unwrap`, so an unknown
    // target must be rejected here to return an error instead of panicking.
    if let Some(t) = &target {
        graph.get_node_index(t)?;
    }
    let parallel =
        graph.number_of_nodes() > SERIAL_TO_PARALLEL_THRESHOLD && rayon::current_num_threads() > 1;
    let shortest_paths_vecs: Vec<(usize, Vec<(usize, ShortestPathInfo<usize>)>)> = if parallel {
        all_pairs_par_iter(graph, weighted, target, cutoff, first_only, with_paths).collect()
    } else {
        all_pairs_iter(graph, weighted, target, cutoff, first_only, with_paths).collect()
    };
    let by_source = shortest_paths_vecs
        .into_iter()
        .map(|(source, shortest_paths)| {
            // `source` comes from `0..number_of_nodes()`, so the lookup cannot fail.
            let source_name = graph
                .get_node_by_index(&source)
                .expect("source index produced from the graph's own node range")
                .name
                .clone();
            let shortest_paths_t = convert_shortest_path_info_vec_to_t_map(graph, shortest_paths);
            (source_name, shortest_paths_t)
        })
        .collect();
    Ok(by_source)
}
/// Lazily runs one single-source search per node, in node-index order, on the
/// current thread.
///
/// Yields `(source_index, results)`, where `results` pairs each reached node index
/// with its shortest-path information. When no target, cutoff, `first_only` or
/// path tracking is requested, the cheaper `dijkstra_basic` is used.
///
/// # Panics
/// Panics if `target` is not a node of `graph`, or if a search returns an error
/// (e.g. contradictory paths caused by negative weights).
fn all_pairs_iter<'a, T, A>(
    graph: &'a Graph<T, A>,
    weighted: bool,
    target: Option<T>,
    cutoff: Option<f64>,
    first_only: bool,
    with_paths: bool,
) -> impl Iterator<Item = (usize, Vec<(usize, ShortestPathInfo<usize>)>)> + 'a
where
    T: Hash + Eq + Clone + Ord + Display + Send + Sync,
    A: Clone + Send + Sync,
{
    let target_index = target
        .as_ref()
        .map(|t| graph.get_node_index(t).unwrap());
    // The algorithm choice does not depend on the source node, so decide it once
    // instead of re-evaluating (and cloning `target`) for every node.
    let use_basic = can_use_basic(target, cutoff, first_only, with_paths);
    (0..graph.number_of_nodes()).map(move |node_index| {
        let result = if use_basic {
            dijkstra_basic(graph, weighted, node_index)
        } else {
            dijkstra(
                graph,
                weighted,
                node_index,
                target_index,
                cutoff,
                first_only,
                with_paths,
            )
        };
        (node_index, result.unwrap())
    })
}
pub(crate) fn all_pairs_par_iter<'a, T, A>(
graph: &'a Graph<T, A>,
weighted: bool,
target: Option<T>,
cutoff: Option<f64>,
first_only: bool,
with_paths: bool,
) -> rayon::iter::Map<
rayon::vec::IntoIter<usize>,
impl Fn(usize) -> (usize, Vec<(usize, ShortestPathInfo<usize>)>) + 'a,
>
where
T: Hash + Eq + Clone + Ord + Display + Send + Sync + 'a,
A: Clone + Send + Sync + 'a,
{
let target_index = match target.clone() {
Some(t) => Some(graph.get_node_index(&t).unwrap()),
None => None,
};
let x = (0..graph.number_of_nodes())
.collect::<Vec<_>>()
.into_par_iter()
.map(move |node_index| {
let ss_index = match can_use_basic(target.clone(), cutoff, first_only, with_paths) {
true => dijkstra_basic(graph, weighted, node_index),
false => dijkstra(
graph,
weighted,
node_index,
target_index,
cutoff,
first_only,
with_paths,
),
}
.unwrap();
(node_index, ss_index)
});
x
}
/// Computes shortest-path information from `source` to every node it reaches.
///
/// The result is keyed by the name of each reached node. `target`, `cutoff`,
/// `first_only` and `with_paths` select the full `dijkstra` search. When none of
/// them is requested, the distance-only `dijkstra_basic` is used.
///
/// # Errors
/// * The node-lookup error if `source` or `target` is not in the graph.
/// * Any error from the search itself (e.g. contradictory paths).
pub fn single_source<T, A>(
    graph: &Graph<T, A>,
    weighted: bool,
    source: T,
    target: Option<T>,
    cutoff: Option<f64>,
    first_only: bool,
    with_paths: bool,
) -> Result<HashMap<T, ShortestPathInfo<T>>, Error>
where
    T: Hash + Eq + Clone + Ord + Display + Send + Sync,
    A: Clone + Send + Sync,
{
    let source_index = graph.get_node_index(&source)?;
    let target_index = target
        .as_ref()
        .map(|t| graph.get_node_index(t))
        .transpose()?;
    let infos = if can_use_basic(target, cutoff, first_only, with_paths) {
        dijkstra_basic(graph, weighted, source_index)?
    } else {
        dijkstra(
            graph,
            weighted,
            source_index,
            target_index,
            cutoff,
            first_only,
            with_paths,
        )?
    };
    Ok(convert_shortest_path_info_vec_to_t_map(graph, infos))
}
pub fn multi_source<T, A>(
graph: &Graph<T, A>,
weighted: bool,
sources: Vec<T>,
target: Option<T>,
cutoff: Option<f64>,
first_only: bool,
with_paths: bool,
) -> Result<HashMap<T, HashMap<T, ShortestPathInfo<T>>>, Error>
where
T: Hash + Eq + Clone + Ord + Display + Send + Sync,
A: Clone + Send + Sync,
{
let parallel =
graph.number_of_nodes() > SERIAL_TO_PARALLEL_THRESHOLD && rayon::current_num_threads() > 1;
if !graph.has_nodes(&sources) {
return Err(Error {
kind: ErrorKind::NodeNotFound,
message: "One or more source nodes not found in graph".to_string(),
});
}
if target.is_some() && !graph.has_node(&target.clone().unwrap()) {
return Err(Error {
kind: ErrorKind::NodeNotFound,
message: "Target node not found in graph".to_string(),
});
}
let shortest_paths: Vec<(T, HashMap<T, ShortestPathInfo<T>>)> = match parallel {
true => sources
.into_par_iter()
.map(|source| {
(
source.clone(),
single_source(
graph,
weighted,
source.clone(),
target.clone(),
cutoff,
first_only,
with_paths,
)
.unwrap(),
)
})
.collect(),
false => sources
.into_iter()
.map(|source| {
(
source.clone(),
single_source(
graph,
weighted,
source.clone(),
target.clone(),
cutoff,
first_only,
with_paths,
)
.unwrap(),
)
})
.collect(),
};
Ok(shortest_paths.into_iter().collect())
}
/// Runs Dijkstra's algorithm from the node at index `source`.
///
/// * `weighted` - use each edge's `weight`; otherwise every edge costs `1.0`.
/// * `target` - if set, the search stops as soon as this node's distance is finalised.
/// * `cutoff` - if set, tentative distances strictly greater than this are not relaxed.
/// * `first_only` - if `true`, only the first shortest path found to a node is kept;
///   otherwise equal-length alternatives are also recorded.
/// * `with_paths` - record node-index paths in addition to distances.
///
/// Returns `(node_index, info)` for every node whose distance was finalised.
///
/// # Errors
/// Returns a `ContradictoryPaths` error if an already-finalised node is reached by a
/// strictly shorter path, which suggests negative edge weights.
fn dijkstra<T, A>(
graph: &Graph<T, A>,
weighted: bool,
source: usize,
target: Option<usize>,
cutoff: Option<f64>,
first_only: bool,
with_paths: bool,
) -> Result<Vec<(usize, ShortestPathInfo<usize>)>, Error>
where
T: Hash + Eq + Clone + Ord + Display + Send + Sync,
A: Clone,
{
// `paths[i]` holds every recorded path from `source` to node `i`; the table is only
// allocated when paths are requested.
let mut paths: Vec<Vec<Vec<usize>>> = vec![];
if with_paths {
paths = vec![vec![]; graph.number_of_nodes()];
paths[source] = vec![vec![source]];
}
// `dist`: finalised distances, with `f64::MAX` meaning "not finalised yet".
// `seen`: best tentative distance pushed so far for each node.
let mut dist = vec![f64::MAX; graph.number_of_nodes()];
let mut seen = vec![f64::MAX; graph.number_of_nodes()];
let mut fringe = BinaryHeap::<FringeNode>::new();
let mut count = 0;
seen[source] = 0.0;
// Fringe distances are stored negated because `BinaryHeap` is a max-heap.
fringe.push(FringeNode {
node_index: source,
count: 0,
distance: -0.0,
});
while let Some(fringe_item) = fringe.pop() {
let d = -fringe_item.distance;
let v = fringe_item.node_index;
// Stale entry: `v` was already finalised by an earlier pop.
if dist[v] != f64::MAX {
continue;
}
dist[v] = d;
if target.as_ref() == Some(&v) {
break;
}
for adj in graph.get_successor_nodes_by_index(&v) {
let u = adj.node_index;
let cost = match weighted {
true => adj.weight,
false => 1.0,
};
let vu_dist = dist[v] + cost;
if cutoff.map_or(false, |c| vu_dist > c) {
continue;
}
if dist[u] != f64::MAX {
// `u` is already final, so a shorter path now is impossible with non-negative
// weights. NOTE(review): an equal-length path to a finalised node (e.g. via a
// zero-weight edge) is not recorded here.
let u_dist = dist[u];
if vu_dist < u_dist {
return Err(get_contractory_paths_error());
}
} else if vu_dist < seen[u] {
// Strictly better tentative distance: replace `u`'s paths with `v`'s paths + `u`.
seen[u] = vu_dist;
push_fringe_node(&mut count, &mut fringe, u, vu_dist);
if with_paths {
let mut new_paths_v = paths[v].clone();
new_paths_v.iter_mut().for_each(|pv| pv.push(u));
paths[u] = new_paths_v;
}
} else if !first_only && vu_dist == seen[u] {
// Equal-length alternative: add `v`'s paths (extended by `u`) to `u`'s paths.
// NOTE(review): `u` already has a queued entry at this distance, so this push looks
// redundant (the extra entry is discarded as stale), although it can change the
// pop order among equal distances.
push_fringe_node(&mut count, &mut fringe, u, vu_dist);
if with_paths {
add_u_to_v_paths_and_append_v_paths_to_u_paths(u, v, &mut paths);
}
}
}
}
// Unreached nodes (still at `f64::MAX`) are filtered out by the helper.
Ok(get_shortest_path_infos::<T, A>(
dist, &mut paths, with_paths,
))
}
/// Distance-only Dijkstra from the node at index `source`.
///
/// This is the fast path used when no target, cutoff, `first_only` restriction or
/// path tracking is requested. Returns `(node_index, info)` for every reachable
/// node, where each `info` has an empty `paths` list.
///
/// Unlike `dijkstra`, this variant does not detect contradictory paths (negative
/// edge weights) and always returns `Ok`.
fn dijkstra_basic<T, A>(
    graph: &Graph<T, A>,
    weighted: bool,
    source: usize,
) -> Result<Vec<(usize, ShortestPathInfo<usize>)>, Error>
where
    T: Hash + Eq + Clone + Ord + Display + Send + Sync,
    A: Clone,
{
    let node_count = graph.number_of_nodes();
    // `dist`: finalised distances (`f64::MAX` = not finalised yet).
    // `seen`: best tentative distance queued so far for each node.
    let mut dist = vec![f64::MAX; node_count];
    let mut seen = vec![f64::MAX; node_count];
    let mut fringe = BinaryHeap::<FringeNode>::new();
    let mut count = 0;
    seen[source] = 0.0;
    // Fringe distances are stored negated because `BinaryHeap` is a max-heap.
    fringe.push(FringeNode {
        node_index: source,
        count: 0,
        distance: -0.0,
    });
    while let Some(fringe_item) = fringe.pop() {
        let v = fringe_item.node_index;
        // Stale entry: `v` was already finalised by an earlier pop.
        if dist[v] != f64::MAX {
            continue;
        }
        dist[v] = -fringe_item.distance;
        for adj in graph.get_successor_nodes_by_index(&v) {
            let u = adj.node_index;
            let cost = if weighted { adj.weight } else { 1.0 };
            let vu_dist = dist[v] + cost;
            // Only a strictly shorter tentative distance needs a new heap entry. An
            // equal one (very common in unweighted graphs) would carry the same
            // distance as the entry already queued and be discarded as stale.
            if vu_dist < seen[u] {
                seen[u] = vu_dist;
                push_fringe_node(&mut count, &mut fringe, u, vu_dist);
            }
        }
    }
    // Paths are never tracked here, so no per-node path table is allocated.
    Ok(get_shortest_path_infos::<T, A>(dist, &mut Vec::new(), false))
}
/// Builds the `ContradictoryPaths` error returned when a search reaches an
/// already-finalised node through a strictly shorter path.
#[inline]
fn get_contractory_paths_error() -> Error {
    let message = String::from(CONTRADICTORY_PATHS_ERROR_MESSAGE);
    Error {
        kind: ErrorKind::ContradictoryPaths,
        message,
    }
}
/// Queues node `u` with tentative distance `vu_dist`.
///
/// `count` is incremented first and its new value is stored on the entry, so later
/// pushes carry larger counts, which `FringeNode`'s ordering uses to break ties.
/// The distance is stored negated because `BinaryHeap` is a max-heap.
#[inline]
fn push_fringe_node(count: &mut i32, fringe: &mut BinaryHeap<FringeNode>, u: usize, vu_dist: f64) {
    *count += 1;
    let entry = FringeNode {
        node_index: u,
        count: *count,
        distance: -vu_dist,
    };
    fringe.push(entry);
}
/// Appends to `paths[u]` a copy of every path in `paths[v]`, each extended by `u`.
///
/// Used when `u` is reached from `v` at a distance equal to its best known one, so
/// every shortest path to `v` followed by the edge `v -> u` is also a shortest
/// path to `u`.
///
/// # Panics
/// Panics if `u` or `v` is out of bounds for `paths`.
#[inline]
fn add_u_to_v_paths_and_append_v_paths_to_u_paths(
    u: usize,
    v: usize,
    paths: &mut [Vec<Vec<usize>>],
) {
    // Build the extended copies before mutating `paths[u]`: `paths[v]` is borrowed
    // while they are created, and this also keeps `u == v` well-defined. Each copy
    // is allocated at its exact final length.
    let extended: Vec<Vec<usize>> = paths[v]
        .iter()
        .map(|path| path.iter().copied().chain(std::iter::once(u)).collect())
        .collect();
    paths[u].extend(extended);
}
/// Collects every all-pairs shortest-path result that, according to
/// `ShortestPathInfo::contains_path_through_node`, has a path through `node_name`.
///
/// Paths are computed with `all_pairs` (no target or cutoff, all equal-length
/// paths, paths recorded). If that computation returns an error, an empty vector
/// is returned.
pub fn get_all_shortest_paths_involving<T, A>(
    graph: &Graph<T, A>,
    node_name: T,
    weighted: bool,
) -> Vec<ShortestPathInfo<T>>
where
    T: Hash + Eq + Clone + Ord + Display + Send + Sync,
    A: Clone + Send + Sync,
{
    let pairs = match all_pairs(graph, weighted, None, None, false, true) {
        Ok(pairs) => pairs,
        Err(_) => return vec![],
    };
    pairs
        .into_values()
        .flat_map(|by_target| by_target.into_values())
        .filter(|info| info.contains_path_through_node(node_name.clone()))
        .collect()
}
/// Packs per-node search results into `(node_index, ShortestPathInfo)` pairs.
///
/// Nodes whose distance is still the `f64::MAX` sentinel were never reached and
/// are skipped. When `with_paths` is `true`, each node's paths are moved out of
/// `paths`, leaving an empty vector behind. Otherwise `paths` is not touched and
/// every result carries an empty path list.
fn get_shortest_path_infos<T, A>(
    distances: Vec<f64>,
    paths: &mut Vec<Vec<Vec<usize>>>,
    with_paths: bool,
) -> Vec<(usize, ShortestPathInfo<usize>)>
where
    T: Hash + Eq + Clone + Ord + Display + Send + Sync,
    A: Clone,
{
    let mut infos = Vec::new();
    for (node, distance) in distances.into_iter().enumerate() {
        if distance == f64::MAX {
            continue;
        }
        let node_paths = if with_paths {
            mem::take(&mut paths[node])
        } else {
            Vec::new()
        };
        infos.push((
            node,
            ShortestPathInfo {
                distance,
                paths: node_paths,
            },
        ));
    }
    infos
}
/// Translates a `ShortestPathInfo` whose paths hold node indices into one whose
/// paths hold node names. The distance is copied unchanged.
///
/// # Panics
/// Panics if a path contains an index that `graph.get_node_by_index` cannot resolve.
fn convert_shortest_path_info_index_to_t<T, A>(
    graph: &Graph<T, A>,
    spi: ShortestPathInfo<usize>,
) -> ShortestPathInfo<T>
where
    T: Hash + Eq + Clone + Ord + Display + Send + Sync,
    A: Clone + Send + Sync,
{
    let index_to_name = |index: &usize| graph.get_node_by_index(index).unwrap().name.clone();
    let paths: Vec<Vec<T>> = spi
        .paths
        .iter()
        .map(|path| path.iter().map(&index_to_name).collect())
        .collect();
    ShortestPathInfo::<T> {
        distance: spi.distance,
        paths,
    }
}
/// Converts index-based `(node_index, info)` results into a map keyed by node name,
/// with each info's paths also translated to node names.
///
/// # Panics
/// Panics if any index cannot be resolved by `graph.get_node_by_index`.
fn convert_shortest_path_info_vec_to_t_map<T, A>(
    graph: &Graph<T, A>,
    spi_map: Vec<(usize, ShortestPathInfo<usize>)>,
) -> HashMap<T, ShortestPathInfo<T>>
where
    T: Hash + Eq + Clone + Ord + Display + Send + Sync,
    A: Clone + Send + Sync,
{
    let mut by_name = HashMap::with_capacity(spi_map.len());
    for (index, info) in spi_map {
        let name = graph.get_node_by_index(&index).unwrap().name.clone();
        by_name.insert(name, convert_shortest_path_info_index_to_t(graph, info));
    }
    by_name
}
/// Reports whether the distance-only fast path (`dijkstra_basic`) can serve a query.
///
/// This holds only when no target, no cutoff, no `first_only` restriction and no
/// path tracking were requested.
fn can_use_basic<T>(
    target: Option<T>,
    cutoff: Option<f64>,
    first_only: bool,
    with_paths: bool,
) -> bool {
    target.is_none() && cutoff.is_none() && !first_only && !with_paths
}