use std::cmp::Ordering;
use crate::csr::CSR;
use crate::traits::{Integer, Scalar};
use anyhow::{format_err, Result};
/// Kind of connectivity computed by [`connected_components`].
///
/// `Weak` always ignores edge direction. `Strong` (together with
/// `directed == true`) groups nodes that are mutually reachable along
/// directed edges.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Connection {
    /// Nodes are connected if some path joins them when edge direction is ignored.
    Weak,
    /// Nodes are connected if each is reachable from the other along directed edges.
    Strong,
}
/// Labels the connected components of the square graph stored in `csgraph`.
///
/// * `directed` — whether a stored entry `(i, j)` is only the edge `i -> j`.
/// * `connection` — `Connection::Weak` always ignores edge direction;
///   `Connection::Strong` with `directed == true` computes strongly connected
///   components.
///
/// Returns `(n_components, labels)` where `labels[v]` is the component index
/// of node `v`.
///
/// # Errors
/// Returns an error if `csgraph` is not square.
pub fn connected_components<I: Integer, T: Scalar>(
    csgraph: &CSR<I, T>,
    directed: bool,
    connection: Connection,
) -> Result<(usize, Vec<usize>)> {
    // Weak connectivity is plain connectivity of the undirected graph.
    let directed = directed && connection != Connection::Weak;
    validate_graph(csgraph, directed)?;
    let n = csgraph.cols();
    if !directed {
        // Scanning the transpose as well lets the search follow edges backwards.
        let transposed = csgraph.t().to_csr();
        return Ok(connected_components_undirected(
            n,
            csgraph.colidx(),
            csgraph.rowptr(),
            transposed.colidx(),
            transposed.rowptr(),
        ));
    }
    Ok(connected_components_directed(
        n,
        csgraph.colidx(),
        csgraph.rowptr(),
    ))
}
// Identity alias: `lowlinks![buf]` expands to `buf`. It marks accesses that
// use the shared `labels` buffer in its "lowlink index" role.
macro_rules! lowlinks {
    ($buffer:ident) => ($buffer);
}
// Identity alias: `stack_f![buf]` expands to `buf`. It marks accesses that
// use the shared `ss` buffer in its "DFS-stack forward link" role.
macro_rules! stack_f {
    ($buffer:ident) => ($buffer);
}
/// Slot value for the index-linked stacks and label arrays used by the
/// component searches.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum Node {
    /// A payload: a node index (stack link), a lowlink index, or a component label.
    Label(usize),
    /// Unset slot: node not yet visited / not on a stack.
    Void,
    /// Terminator of an index-linked stack.
    End,
}
impl Node {
    /// Returns the payload of a `Node::Label`.
    ///
    /// # Panics
    /// Panics if `self` is `Node::Void` or `Node::End`.
    fn unwrap(self) -> usize {
        if let Node::Label(value) = self {
            value
        } else if let Node::Void = self {
            panic!("called `Label::unwrap()` on a `Void` value")
        } else {
            panic!("called `Label::unwrap()` on a `End` value")
        }
    }
}
/// Total order `End < Void < Label(_)`; two labels compare by their values.
impl PartialOrd for Node {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Map every variant to a (rank, payload) key and compare keys.
        fn rank(node: &Node) -> (u8, usize) {
            match *node {
                Node::End => (0, 0),
                Node::Void => (1, 0),
                Node::Label(value) => (2, value),
            }
        }
        Some(rank(self).cmp(&rank(other)))
    }
}
fn connected_components_directed<I: Integer>(
n: usize,
indices: &[I],
indptr: &[I],
) -> (usize, Vec<usize>) {
let mut labels = vec![Node::Void; n];
let mut ss = vec![Node::Void; n];
let mut stack_b = vec![Node::Void; n];
let mut ss_head = Node::End;
let mut stack_head;
let mut index = 0;
let mut label = Node::Label(n - 1);
for v in 0..n {
if lowlinks![labels][v] == Node::Void {
stack_head = Node::Label(v);
stack_f![ss][v] = Node::End;
stack_b[v] = Node::End;
while stack_head != Node::End {
let v = stack_head.unwrap();
if lowlinks![labels][v] == Node::Void {
lowlinks![labels][v] = Node::Label(index);
index += 1;
for j in indptr[v].to_usize().unwrap()..indptr[v + 1].to_usize().unwrap() {
let w = indices[j].to_usize().unwrap();
if lowlinks![labels][w] == Node::Void {
if stack_f![ss][w] != Node::Void {
let f = stack_f![ss][w];
let b = stack_b[w];
if b != Node::End {
stack_f![ss][b.unwrap()] = f;
}
if f != Node::End {
stack_b[f.unwrap()] = b;
}
}
stack_f![ss][w] = stack_head;
stack_b[w] = Node::End;
stack_b[stack_head.unwrap()] = Node::Label(w);
stack_head = Node::Label(w);
}
}
} else {
stack_head = stack_f![ss][v];
if stack_head != Node::Void && stack_head != Node::End {
stack_b[stack_head.unwrap()] = Node::End;
}
stack_f![ss][v] = Node::Void;
stack_b[v] = Node::Void;
let mut root = true;
let mut low_v = lowlinks![labels][v];
for j in indptr[v].to_usize().unwrap()..indptr[v + 1].to_usize().unwrap() {
let low_w = lowlinks![labels][indices[j].to_usize().unwrap()];
if low_w < low_v {
low_v = low_w;
root = false;
}
}
lowlinks![labels][v] = low_v;
if root {
index -= 1;
while ss_head != Node::End
&& lowlinks![labels][v] <= lowlinks![labels][ss_head.unwrap()]
{
let w = ss_head.unwrap();
ss_head = ss[w];
ss[w] = Node::Void;
labels[w] = label;
index -= 1;
}
labels[v] = label;
label = match label {
Node::Label(l) => {
if l == 0 {
Node::Void
} else {
Node::Label(l - 1)
}
}
Node::Void => Node::End,
Node::End => {
unreachable!("label with End node mark");
}
}
} else {
ss[v] = ss_head;
ss_head = Node::Label(v);
}
}
}
}
}
(
match label {
Node::Label(l) => (n - 1) - l,
_ => n, },
labels.iter().map(|l| (n - 1) - l.unwrap()).collect(),
)
}
// Identity alias: `ss![buf]` expands to `buf`. It marks accesses that use
// the `labels` buffer as the intrusive pending-node stack.
macro_rules! ss {
    ($buffer:ident) => ($buffer);
}
/// Labels the connected components of a graph, ignoring edge direction.
///
/// Every node's neighbours are read from two CSR structures:
/// `indices1`/`indptr1` and `indices2`/`indptr2`. `connected_components`
/// passes the graph and its transpose, so both edge directions are followed.
/// `n` is the node count.
///
/// Returns `(n_components, labels)` with `labels[v]` in `0..n_components`.
/// Components are numbered in the order of their lowest node index.
///
/// `labels` doubles as an intrusive stack (accessed through `ss!`). While a
/// node is pending, its slot links to the next pending node, and `Node::End`
/// terminates the stack. Once the node is popped, the slot receives the
/// component label.
fn connected_components_undirected<I: Integer>(
    n: usize,
    indices1: &[I],
    indptr1: &[I],
    indices2: &[I],
    indptr2: &[I],
) -> (usize, Vec<usize>) {
    let mut labels = vec![Node::Void; n];
    let mut component = 0;
    for seed in 0..n {
        // Nodes reached from an earlier seed already belong to a component.
        if labels[seed] != Node::Void {
            continue;
        }
        // Start a fresh stack holding only `seed`.
        ss![labels][seed] = Node::End;
        let mut top = Node::Label(seed);
        while top != Node::End {
            // Pop the top node and stamp it with the current component.
            let node = top.unwrap();
            top = ss![labels][node];
            labels[node] = Node::Label(component);

            // Push never-seen neighbours from the first CSR structure...
            let first = indptr1[node].to_usize().unwrap()..indptr1[node + 1].to_usize().unwrap();
            for w in first.map(|j| indices1[j].to_usize().unwrap()) {
                if ss![labels][w] == Node::Void {
                    ss![labels][w] = top;
                    top = Node::Label(w);
                }
            }
            // ...and from the second one.
            let second = indptr2[node].to_usize().unwrap()..indptr2[node + 1].to_usize().unwrap();
            for w in second.map(|j| indices2[j].to_usize().unwrap()) {
                if ss![labels][w] == Node::Void {
                    ss![labels][w] = top;
                    top = Node::Label(w);
                }
            }
        }
        component += 1;
    }
    let labels: Vec<usize> = labels.into_iter().map(Node::unwrap).collect();
    (component, labels)
}
/// Labels connected components by repeatedly seeding an unlabelled node and
/// spreading its label level by level along stored edges.
///
/// This appears to be a port of SciPy sparsetools' `cs_graph_components`.
/// NOTE(review): it presumably expects a structurally symmetric graph; confirm.
///
/// * `n_nod` — number of nodes (rows).
/// * `a_p`, `a_j` — CSR row pointers (at least `n_nod + 1` entries when
///   `n_nod > 0`) and column indices.
/// * `flag` — output with at least `n_nod` entries. On success, `flag[v]` is
///   the component index of `v` (`0..n_comp`), or `-2` if row `v` has no
///   stored entries. Such nodes are never counted as components.
///
/// The type parameter `T` is not used by the body. Removing it would break
/// callers that name it with a turbofish.
///
/// Returns the number of components found (`0` when every row is empty).
///
/// # Errors
/// Returns an error if `a_p` or `flag` is too short for `n_nod`, or if no
/// seed node can be found ("graph is corrupted").
pub fn cs_graph_components<I: Integer, T: Scalar, F: Integer + num_traits::Signed>(
    n_nod: usize,
    a_p: &[I],
    a_j: &[I],
    flag: &mut [F],
) -> Result<I> {
    // Reject undersized buffers up front instead of panicking on an index.
    if n_nod > 0 && a_p.len() <= n_nod {
        return Err(format_err!(
            "a_p must have at least n_nod + 1 = {} entries, got {}",
            n_nod + 1,
            a_p.len()
        ));
    }
    if flag.len() < n_nod {
        return Err(format_err!(
            "flag must have at least n_nod = {} entries, got {}",
            n_nod,
            flag.len()
        ));
    }

    let unvisited = F::from(-1).unwrap();
    let isolated = F::from(-2).unwrap();

    // `pos` lists the nodes of the component being grown, in discovery order.
    let mut pos = vec![I::one(); n_nod];
    let mut n_comp = I::zero();
    // Number of nodes that must be labelled, i.e. nodes with a non-empty row.
    let mut n_stop = n_nod;
    for ir in 0..n_nod {
        flag[ir] = unvisited;
        if a_p[ir + 1] == a_p[ir] {
            n_stop -= 1;
            flag[ir] = isolated;
        }
    }
    // Every row is empty, so there is nothing to seed from. Without this
    // check the seed search below would report a corrupted graph.
    if n_stop == 0 {
        return Ok(n_comp);
    }

    let mut n_tot = 0;
    for icomp in 0..n_nod {
        // Seed: the first node that is neither labelled nor isolated.
        let mut ii = 0;
        while (flag[ii] >= F::zero()) || (flag[ii] == isolated) {
            ii += 1;
            if ii >= n_nod {
                return Err(format_err!("graph is corrupted"));
            }
        }
        let component = F::from(icomp).unwrap();
        flag[ii] = component;
        pos[0] = I::from(ii).unwrap();
        // `pos[n_pos0..n_pos]` is the current frontier; newly found nodes are
        // appended from `n_pos_new`.
        let mut n_pos0 = 0;
        let mut n_pos = 1;
        let mut n_pos_new = n_pos;
        for _ in 0..n_nod {
            let mut n_new = 0;
            for ir in n_pos0..n_pos {
                let pos_ir = pos[ir].to_usize().unwrap();
                let start = a_p[pos_ir].to_usize().unwrap();
                let end = a_p[pos_ir + 1].to_usize().unwrap();
                for ic in start..end {
                    let aj_ic = a_j[ic].to_usize().unwrap();
                    if flag[aj_ic] == unvisited {
                        flag[aj_ic] = component;
                        pos[n_pos_new] = a_j[ic];
                        n_pos_new += 1;
                        n_new += 1;
                    }
                }
            }
            n_pos0 = n_pos;
            n_pos = n_pos_new;
            if n_new == 0 {
                break;
            }
        }
        n_tot += n_pos;
        // All non-isolated nodes are labelled: done.
        if n_tot == n_stop {
            n_comp = I::from(icomp + 1).unwrap();
            break;
        }
    }
    Ok(n_comp)
}
pub fn depth_first_order<I: Integer, T: Scalar>(
csgraph: &CSR<I, T>,
i_start: usize,
directed: bool,
) -> Result<(Vec<usize>, Vec<Option<usize>>)> {
validate_graph(csgraph, directed)?;
let n = csgraph.cols();
let mut node_list = vec![0; n];
let mut predecessors = vec![None; n];
let length = if directed {
depth_first_directed(
i_start,
csgraph.colidx(),
csgraph.rowptr(),
&mut node_list,
&mut predecessors,
)
} else {
let csgraph_t = csgraph.t().to_csr();
depth_first_undirected(
i_start,
csgraph.colidx(),
csgraph.rowptr(),
csgraph_t.colidx(),
csgraph_t.rowptr(),
&mut node_list,
&mut predecessors,
)
};
Ok((node_list[..length].to_vec(), predecessors))
}
/// Iterative depth-first search over directed CSR edges from `head_node`.
///
/// Writes the visited nodes, in discovery order, to the front of
/// `node_list`, and records each discovered node's parent in `predecessors`.
/// The length of `node_list` is taken as the node count. Returns how many
/// entries of `node_list` were written.
///
/// Each time the search returns to a node, its adjacency list is scanned
/// from the beginning for the first unvisited successor.
fn depth_first_directed<I: Integer>(
    head_node: usize,
    indices: &[I],
    indptr: &[I],
    node_list: &mut [usize],
    predecessors: &mut [Option<usize>],
) -> usize {
    let total = node_list.len();
    // `path[..=depth]` is the current root-to-node DFS path.
    let mut path = vec![0usize; total];
    let mut visited = vec![false; total];
    let mut depth: isize = 0;
    let mut emitted = 1;
    node_list[0] = head_node;
    path[0] = head_node;
    visited[head_node] = true;
    while depth >= 0 {
        let parent = path[depth as usize];
        let lo = indptr[parent].to_usize().unwrap();
        let hi = indptr[parent + 1].to_usize().unwrap();
        let next = (lo..hi)
            .map(|k| indices[k].to_usize().unwrap())
            .find(|&child| !visited[child]);
        if let Some(child) = next {
            // Descend into the first unvisited successor.
            depth += 1;
            path[depth as usize] = child;
            node_list[emitted] = child;
            predecessors[child] = Some(parent);
            visited[child] = true;
            emitted += 1;
        }
        // Every node has been reached; no need to unwind the path.
        if emitted == total {
            break;
        }
        if next.is_none() {
            // Dead end: backtrack one level.
            depth -= 1;
        }
    }
    emitted
}
fn depth_first_undirected<I: Integer>(
head_node: usize,
indices1: &[I],
indptr1: &[I],
indices2: &[I],
indptr2: &[I],
node_list: &mut [usize],
predecessors: &mut [Option<usize>],
) -> usize {
let n = node_list.len();
let mut root_list = vec![0; n];
let mut flag = vec![false; n];
node_list[0] = head_node;
root_list[0] = head_node;
let mut i_root: isize = 0;
let mut i_nl_end = 1;
flag[head_node] = true;
while i_root >= 0 {
let pnode = root_list[i_root as usize];
let mut no_children = true;
for i in indptr1[pnode].to_usize().unwrap()..indptr1[pnode + 1].to_usize().unwrap() {
let cnode = indices1[i].to_usize().unwrap();
if flag[cnode] {
continue;
} else {
i_root += 1;
root_list[i_root as usize] = cnode;
node_list[i_nl_end] = cnode;
predecessors[cnode] = Some(pnode);
flag[cnode] = true;
i_nl_end += 1;
no_children = false;
break;
}
}
if no_children {
for i in indptr2[pnode].to_usize().unwrap()..indptr2[pnode + 1].to_usize().unwrap() {
let cnode = indices2[i].to_usize().unwrap();
if flag[cnode] {
continue;
} else {
i_root += 1;
root_list[i_root as usize] = cnode;
node_list[i_nl_end] = cnode;
predecessors[cnode] = Some(pnode);
flag[cnode] = true;
i_nl_end += 1;
no_children = false;
break;
}
}
}
if i_nl_end == n {
break;
}
if no_children {
i_root -= 1
}
}
i_nl_end
}
/// Checks that `csgraph` is square (N x N).
///
/// `_directed` is accepted but not inspected.
///
/// # Errors
/// Returns an error when the row count differs from the column count.
fn validate_graph<I: Integer, T: Scalar>(csgraph: &CSR<I, T>, _directed: bool) -> Result<()> {
    let (rows, cols) = (csgraph.rows(), csgraph.cols());
    if rows == cols {
        return Ok(());
    }
    Err(format_err!("compressed-sparse graph must be shape (N, N)"))
}
#[cfg(test)]
mod tests {
    use crate::csr::CSR;
    use crate::graph::{connected_components, Connection};

    // Labels are compared as whole vectors so that a result of the wrong
    // length fails the test instead of being silently truncated.

    #[test]
    fn test_weak_connections() {
        let x_de = vec![vec![0, 1, 0], vec![0, 0, 0], vec![0, 0, 0]];
        let x_sp = CSR::<usize, usize>::from_dense(&x_de);
        let (n_components, labels) = connected_components(&x_sp, true, Connection::Weak).unwrap();
        assert_eq!(n_components, 2);
        assert_eq!(labels, [0, 0, 1]);
    }

    #[test]
    fn test_strong_connections() {
        let x1_de = vec![vec![0, 1, 0], vec![0, 0, 0], vec![0, 0, 0]];
        let x1_sp = CSR::<usize, usize>::from_dense(&x1_de);
        let x2_sp = &x1_sp + &x1_sp.t().to_csr();

        let (n_components, mut labels) =
            connected_components(&x1_sp, true, Connection::Strong).unwrap();
        assert_eq!(n_components, 3);
        labels.sort();
        assert_eq!(labels, [0, 1, 2]);

        let (n_components, mut labels) =
            connected_components(&x2_sp, true, Connection::Strong).unwrap();
        assert_eq!(n_components, 2);
        labels.sort();
        assert_eq!(labels, [0, 0, 1]);
    }

    #[test]
    fn test_strong_connections2() {
        let x = vec![
            vec![0, 0, 0, 0, 0, 0],
            vec![1, 0, 1, 0, 0, 0],
            vec![0, 0, 0, 1, 0, 0],
            vec![0, 0, 1, 0, 1, 0],
            vec![0, 0, 0, 0, 0, 0],
            vec![0, 0, 0, 0, 1, 0],
        ];
        let x_sp = CSR::<usize, usize>::from_dense(&x);
        let (n_components, mut labels) =
            connected_components(&x_sp, true, Connection::Strong).unwrap();
        assert_eq!(n_components, 5);
        labels.sort();
        assert_eq!(labels, [0, 1, 2, 2, 3, 4]);
    }

    #[test]
    fn test_weak_connections2() {
        let x = vec![
            vec![0, 0, 0, 0, 0, 0],
            vec![1, 0, 0, 0, 0, 0],
            vec![0, 0, 0, 1, 0, 0],
            vec![0, 0, 1, 0, 1, 0],
            vec![0, 0, 0, 0, 0, 0],
            vec![0, 0, 0, 0, 1, 0],
        ];
        let x_sp = CSR::<usize, usize>::from_dense(&x);
        let (n_components, mut labels) =
            connected_components(&x_sp, true, Connection::Weak).unwrap();
        assert_eq!(n_components, 2);
        labels.sort();
        assert_eq!(labels, [0, 0, 1, 1, 1, 1]);
    }
}