use std::hash::Hash;
use hashbrown::HashMap;
use fixedbitset::FixedBitSet;
use ndarray::prelude::*;
use petgraph::visit::{
GraphProp, IntoNeighborsDirected, IntoNodeIdentifiers, NodeCount, NodeIndexable,
};
use petgraph::{Incoming, Outgoing};
use rayon::prelude::*;
/// Compute the all-pairs unweighted distance matrix of `graph`.
///
/// A breadth-first search is started from every node; entry `[i, j]` of the
/// result is the number of edges on a shortest path from the node at matrix
/// position `i` to the node at position `j`. The diagonal is `0.0`, and every
/// pair for which the search never reaches `j` from `i` keeps `null_value`.
///
/// Matrix positions correspond to `graph.to_index(node)` when
/// `node_count() == node_bound()`. Otherwise (the index space is not compact)
/// positions follow the iteration order of `graph.node_identifiers()`.
///
/// # Arguments
///
/// * `graph` - The graph to traverse.
/// * `parallel_threshold` - When the node count is strictly below this value
///   the rows are filled sequentially; otherwise they are filled with a rayon
///   parallel iterator.
/// * `as_undirected` - When `true`, both incoming and outgoing neighbors of a
///   node are treated as adjacent; when `false`, only `graph.neighbors(..)`
///   is used.
/// * `null_value` - The value every matrix cell is initialised with, and thus
///   the value left in cells for unreachable pairs.
///
/// # Returns
///
/// A `node_count x node_count` array of `f64` distances.
pub fn distance_matrix<G>(
    graph: G,
    parallel_threshold: usize,
    as_undirected: bool,
    null_value: f64,
) -> Array2<f64>
where
    G: Sync + IntoNeighborsDirected + NodeCount + NodeIndexable + IntoNodeIdentifiers + GraphProp,
    G::NodeId: Hash + Eq + Sync,
{
    let node_total = graph.node_count();
    // If the index bound differs from the node count, the graph's own indices
    // cannot be used directly as matrix positions, so explicit lookup tables
    // are built in both directions.
    let needs_remap = node_total != graph.node_bound();

    let id_to_pos: HashMap<G::NodeId, usize> = if needs_remap {
        graph
            .node_identifiers()
            .enumerate()
            .map(|(pos, id)| (id, pos))
            .collect()
    } else {
        HashMap::new()
    };
    let pos_to_id: Vec<G::NodeId> = if needs_remap {
        graph.node_identifiers().collect()
    } else {
        Vec::new()
    };

    // Node id -> matrix position.
    let position_of = |id: G::NodeId| -> usize {
        if needs_remap {
            id_to_pos[&id]
        } else {
            graph.to_index(id)
        }
    };
    // Matrix position -> node id.
    let id_at = |pos: usize| -> G::NodeId {
        if needs_remap {
            pos_to_id[pos]
        } else {
            graph.from_index(pos)
        }
    };

    // One bitset per matrix position holding the positions of its neighbors.
    let adjacency = (0..node_total)
        .map(|pos| {
            let id = id_at(pos);
            if as_undirected {
                graph
                    .neighbors_directed(id, Incoming)
                    .chain(graph.neighbors_directed(id, Outgoing))
                    .map(&position_of)
                    .collect::<FixedBitSet>()
            } else {
                graph
                    .neighbors(id)
                    .map(&position_of)
                    .collect::<FixedBitSet>()
            }
        })
        .collect::<Vec<_>>();

    let mut matrix = Array2::<f64>::from_elem((node_total, node_total), null_value);

    // Level-synchronous BFS from `source`, writing the depth of every reached
    // position into `out_row`.
    let fill_row = |source: usize, mut out_row: ArrayViewMut1<f64>| {
        let mut visited = FixedBitSet::with_capacity(node_total);
        let mut frontier = FixedBitSet::with_capacity(node_total);
        let mut upcoming = FixedBitSet::with_capacity(node_total);
        let mut depth = 0.0;
        frontier.put(source);
        while !frontier.is_clear() {
            upcoming.clear();
            for reached in frontier.ones() {
                out_row[[reached]] = depth;
                upcoming |= &adjacency[reached];
            }
            // Everything in the current frontier is now settled; drop settled
            // positions from the next level so each is visited only once.
            visited.union_with(&frontier);
            upcoming.difference_with(&visited);
            depth += 1.0;
            std::mem::swap(&mut frontier, &mut upcoming);
        }
    };

    if node_total >= parallel_threshold {
        matrix
            .axis_iter_mut(Axis(0))
            .into_par_iter()
            .enumerate()
            .for_each(|(source, out_row)| fill_row(source, out_row));
    } else {
        matrix
            .axis_iter_mut(Axis(0))
            .enumerate()
            .for_each(|(source, out_row)| fill_row(source, out_row));
    }
    matrix
}