use super::{num_vertices, to_adjacency_list, validate_graph};
use crate::error::{SparseError, SparseResult};
use crate::sparray::SparseArray;
use scirs2_core::ndarray::Array1;
use scirs2_core::numeric::{Float, SparseElement};
use std::fmt::Debug;
/// Counts the connected components of a sparse graph, optionally labelling each vertex.
///
/// # Arguments
///
/// * `graph` - Sparse graph; it is passed to `validate_graph` before any other work.
/// * `directed` - Whether the graph is treated as directed. When `false` the
///   undirected routine is always used; `connection` is still validated but
///   otherwise has no effect.
/// * `connection` - `"weak"` or `"strong"`, compared case-insensitively.
/// * `returnlabels` - Forwarded to the selected routine to request per-vertex labels.
///
/// # Returns
///
/// The `(component_count, labels)` pair produced by the selected routine.
///
/// # Errors
///
/// Propagates errors from `validate_graph` and from the selected routine, and
/// returns `SparseError::ValueError` when `connection` is neither `"weak"` nor
/// `"strong"`.
#[allow(dead_code)]
pub fn connected_components<T, S>(
    graph: &S,
    directed: bool,
    connection: &str,
    returnlabels: bool,
) -> SparseResult<(usize, Option<Array1<usize>>)>
where
    T: Float + SparseElement + Debug + Copy + 'static,
    S: SparseArray<T>,
{
    validate_graph(graph, directed)?;

    // The mode is parsed even for undirected graphs so that a bad string is always reported.
    let requested = connection.to_lowercase();
    let mode = if requested == "weak" {
        ConnectionType::Weak
    } else if requested == "strong" {
        ConnectionType::Strong
    } else {
        return Err(SparseError::ValueError(format!(
            "Unknown connection type: {connection}. Use 'weak' or 'strong'"
        )));
    };

    if !directed {
        return undirected_connected_components(graph, returnlabels);
    }

    match mode {
        ConnectionType::Weak => weakly_connected_components(graph, returnlabels),
        ConnectionType::Strong => strongly_connected_components(graph, returnlabels),
    }
}
/// Connectivity notion parsed from the `connection` string of `connected_components`.
#[derive(Debug, Clone, Copy, PartialEq)]
enum ConnectionType {
/// Parsed from `"weak"`; for directed graphs this selects `weakly_connected_components`.
Weak,
/// Parsed from `"strong"`; for directed graphs this selects `strongly_connected_components`.
Strong,
}
/// Counts connected components by running a depth-first sweep from every unvisited vertex.
///
/// The adjacency list is requested from `to_adjacency_list` with `directed = false`.
/// Components are numbered from `0` in the order their lowest-numbered vertex is
/// reached by the outer scan.
///
/// # Returns
///
/// `(component_count, labels)`, where `labels` is `Some` array of length
/// `num_vertices(graph)` holding each vertex's component number when
/// `returnlabels` is `true`, and `None` otherwise.
///
/// # Errors
///
/// Propagates any error from `to_adjacency_list`.
#[allow(dead_code)]
pub fn undirected_connected_components<T, S>(
    graph: &S,
    returnlabels: bool,
) -> SparseResult<(usize, Option<Array1<usize>>)>
where
    T: Float + SparseElement + Debug + Copy + 'static,
    S: SparseArray<T>,
{
    let vertex_count = num_vertices(graph);
    let neighbours = to_adjacency_list(graph, false)?;

    let mut seen = vec![false; vertex_count];
    let mut assignments: Option<Array1<usize>> =
        returnlabels.then(|| Array1::zeros(vertex_count));

    let mut found = 0;
    for root in 0..vertex_count {
        if seen[root] {
            continue;
        }
        // Everything reachable from `root` receives the current component number.
        dfs_component(&neighbours, root, &mut seen, found, &mut assignments);
        found += 1;
    }

    Ok((found, assignments))
}
/// Counts the weakly connected components of a graph.
///
/// This is a direct delegation to `undirected_connected_components`, which builds
/// its adjacency list with `to_adjacency_list(graph, false)`.
/// NOTE(review): this presumably makes every stored edge traversable in both
/// directions, which is what weak connectivity requires — confirm against
/// `to_adjacency_list`.
///
/// Returns and errors exactly as `undirected_connected_components` does.
#[allow(dead_code)]
pub fn weakly_connected_components<T, S>(
graph: &S,
returnlabels: bool,
) -> SparseResult<(usize, Option<Array1<usize>>)>
where
T: Float + SparseElement + Debug + Copy + 'static,
S: SparseArray<T>,
{
undirected_connected_components(graph, returnlabels)
}
#[allow(dead_code)]
pub fn strongly_connected_components<T, S>(
graph: &S,
returnlabels: bool,
) -> SparseResult<(usize, Option<Array1<usize>>)>
where
T: Float + SparseElement + Debug + Copy + 'static,
S: SparseArray<T>,
{
let n = num_vertices(graph);
let adjlist = to_adjacency_list(graph, true)?;
let mut tarjan = TarjanSCC::<T>::new(n, returnlabels);
for v in 0..n {
if tarjan.indices[v] == -1 {
tarjan.strongconnect(v, &adjlist);
}
}
Ok((tarjan.component_count, tarjan._labels))
}
/// Marks every vertex reachable from `start` as visited and, when a label array
/// is present, stamps each of them with `component_id`.
///
/// The traversal is iterative with an explicit stack. A vertex may be pushed more
/// than once, so entries that are already visited when popped are skipped.
///
/// * `adjlist` - Neighbour list per vertex; the edge weight is ignored.
/// * `visited` - Visited flags, updated in place.
/// * `labels` - Optional per-vertex label array, updated in place.
#[allow(dead_code)]
fn dfs_component<T>(
    adjlist: &[Vec<(usize, T)>],
    start: usize,
    visited: &mut [bool],
    component_id: usize,
    labels: &mut Option<Array1<usize>>,
) where
    T: Float + SparseElement + Debug + Copy + 'static,
{
    let mut pending = Vec::new();
    pending.push(start);

    while let Some(current) = pending.pop() {
        if !visited[current] {
            visited[current] = true;
            if let Some(assignments) = labels.as_mut() {
                assignments[current] = component_id;
            }
            // Queue only neighbours that are still unvisited at this moment.
            pending.extend(
                adjlist[current]
                    .iter()
                    .map(|&(next, _)| next)
                    .filter(|&next| !visited[next]),
            );
        }
    }
}
/// Working state for the strongly-connected-components search run by `strongconnect`.
struct TarjanSCC<T>
where
T: Float + SparseElement + Debug + Copy + 'static,
{
/// Discovery index of each vertex; `-1` means the vertex has not been visited yet.
indices: Vec<isize>,
/// Smallest discovery index recorded so far for each vertex (starts at `-1`).
lowlinks: Vec<isize>,
/// Whether each vertex is currently held in `stack`.
on_stack: Vec<bool>,
/// Visited vertices that have not yet been assigned to a finished component.
stack: Vec<usize>,
/// Discovery index that will be given to the next newly visited vertex.
index: isize,
/// Number of components completed so far; also the label of the next component.
component_count: usize,
/// Per-vertex component labels, present only when labels were requested in `new`.
_labels: Option<Array1<usize>>,
/// Ties the state to the edge weight type `T`, which appears only in method signatures.
_phantom: std::marker::PhantomData<T>,
}
impl<T> TarjanSCC<T>
where
    T: Float + SparseElement + Debug + Copy + 'static,
{
    /// Creates the search state for a graph with `n` vertices.
    ///
    /// Every vertex starts unvisited (index `-1`). A label array of length `n` is
    /// allocated only when `returnlabels` is `true`.
    fn new(n: usize, returnlabels: bool) -> Self {
        Self {
            indices: vec![-1; n],
            lowlinks: vec![-1; n],
            on_stack: vec![false; n],
            stack: Vec::new(),
            index: 0,
            component_count: 0,
            _labels: returnlabels.then(|| Array1::zeros(n)),
            _phantom: std::marker::PhantomData,
        }
    }

    /// Assigns the next discovery index to `v` and pushes it on the component stack.
    fn discover(&mut self, v: usize) {
        self.indices[v] = self.index;
        self.lowlinks[v] = self.index;
        self.index += 1;
        self.stack.push(v);
        self.on_stack[v] = true;
    }

    /// Pops the component rooted at `root` off the stack, labels its members with
    /// the current component number and advances the component counter.
    fn close_component(&mut self, root: usize) {
        let id = self.component_count;
        // `root` is on the stack whenever this is called, so the loop ends at it.
        while let Some(w) = self.stack.pop() {
            self.on_stack[w] = false;
            if let Some(labels) = self._labels.as_mut() {
                labels[w] = id;
            }
            if w == root {
                break;
            }
        }
        self.component_count += 1;
    }

    /// Runs the search from `v`, which the caller must pass only while it is
    /// still unvisited (`indices[v] == -1`).
    ///
    /// The depth-first traversal keeps its own frame stack on the heap instead of
    /// recursing, so a long path in the graph cannot overflow the thread's call
    /// stack. Vertices are discovered and components are completed in the same
    /// order as in the recursive formulation, so the resulting labels and
    /// component count are the same.
    ///
    /// * `adjlist` - Successor list per vertex; the edge weight is ignored.
    fn strongconnect(&mut self, v: usize, adjlist: &[Vec<(usize, T)>]) {
        self.discover(v);
        // Each frame is (vertex, position of the next outgoing edge to examine).
        let mut frames: Vec<(usize, usize)> = vec![(v, 0)];

        while let Some(&(node, cursor)) = frames.last() {
            match adjlist[node].get(cursor) {
                Some(&(w, _)) => {
                    // Advance the cursor first so the edge is not revisited on return.
                    if let Some(frame) = frames.last_mut() {
                        frame.1 = cursor + 1;
                    }
                    if self.indices[w] == -1 {
                        // Tree edge: descend into `w`; its lowlink is merged on return.
                        self.discover(w);
                        frames.push((w, 0));
                    } else if self.on_stack[w] {
                        // Edge to a vertex of the component still under construction.
                        self.lowlinks[node] = self.lowlinks[node].min(self.indices[w]);
                    }
                }
                None => {
                    // All edges of `node` are done: this is the "return" of the call.
                    frames.pop();
                    if self.lowlinks[node] == self.indices[node] {
                        self.close_component(node);
                    }
                    if let Some(&(parent, _)) = frames.last() {
                        self.lowlinks[parent] = self.lowlinks[parent].min(self.lowlinks[node]);
                    }
                }
            }
        }
    }
}
/// Reports whether the graph consists of exactly one connected component.
///
/// The count comes from `connected_components` called with the `"strong"`
/// connection mode, so a directed graph must be strongly connected to return
/// `true`. For an undirected graph the mode has no effect on the count.
///
/// # Errors
///
/// Propagates any error from `connected_components`.
#[allow(dead_code)]
pub fn is_connected<T, S>(graph: &S, directed: bool) -> SparseResult<bool>
where
    T: Float + SparseElement + Debug + Copy + 'static,
    S: SparseArray<T>,
{
    // Labels are not needed; only the number of components matters.
    connected_components(graph, directed, "strong", false).map(|(total, _)| total == 1)
}
/// Finds the component containing the most vertices.
///
/// # Arguments
///
/// * `graph`, `directed`, `connection` - Forwarded unchanged to `connected_components`.
///
/// # Returns
///
/// `(size, vertices)`, where `vertices` lists the members of the largest
/// component in increasing vertex order. When several components share the
/// maximum size, the one with the highest component number is chosen. A graph
/// with no components yields `(0, vec![])`.
///
/// # Errors
///
/// Propagates any error from `connected_components`.
#[allow(dead_code)]
pub fn largest_component<T, S>(
    graph: &S,
    directed: bool,
    connection: &str,
) -> SparseResult<(usize, Vec<usize>)>
where
    T: Float + SparseElement + Debug + Copy + 'static,
    S: SparseArray<T>,
{
    let (n_components, labels) = connected_components(graph, directed, connection, true)?;

    // Without this guard the size lookup below would index an empty vector.
    if n_components == 0 {
        return Ok((0, Vec::new()));
    }

    // Labels were requested above, so the routines in this module return `Some`.
    let labels = labels.expect("Operation failed");

    let mut component_sizes = vec![0usize; n_components];
    for &label in labels.iter() {
        component_sizes[label] += 1;
    }

    // `max_by_key` keeps the last maximum, which fixes the tie-breaking rule.
    let (largest_component_id, largest_size) = component_sizes
        .iter()
        .copied()
        .enumerate()
        .max_by_key(|&(_, size)| size)
        .unwrap_or((0, 0));

    let mut largest_indices = Vec::with_capacity(largest_size);
    largest_indices.extend(
        labels
            .iter()
            .enumerate()
            .filter(|&(_, &label)| label == largest_component_id)
            .map(|(vertex, _)| vertex),
    );

    Ok((largest_size, largest_indices))
}
/// Builds a graph holding only the edges internal to the largest component,
/// with its vertices renumbered `0..k` in increasing original order.
///
/// The result starts as a clone of `graph`: every stored entry is overwritten
/// with zero, `eliminate_zeros` is called, and the renumbered edges are then
/// written back with `set`.
/// NOTE(review): the result of every `set` call is discarded, so entries the
/// storage format refuses to write are silently missing; and because the result
/// begins as a clone it presumably keeps the original shape — confirm both
/// against the `SparseArray` implementation in use.
///
/// # Returns
///
/// `(subgraph, vertex_indices)`, where `vertex_indices[new]` is the original
/// index of renumbered vertex `new`.
///
/// # Errors
///
/// Propagates any error from `largest_component`.
#[allow(dead_code)]
pub fn extract_largest_component<T, S>(
    graph: &S,
    directed: bool,
    connection: &str,
) -> SparseResult<(S, Vec<usize>)>
where
    T: Float + SparseElement + Debug + Copy + 'static,
    S: SparseArray<T> + Clone,
{
    let (_, vertex_indices) = largest_component(graph, directed, connection)?;
    let original_n = num_vertices(graph);

    // remap[old] is the new index of a kept vertex, or None for a dropped one.
    let mut remap: Vec<Option<usize>> = vec![None; original_n];
    vertex_indices
        .iter()
        .enumerate()
        .for_each(|(new_idx, &old_idx)| remap[old_idx] = Some(new_idx));

    // Keep only entries whose two endpoints both belong to the largest component.
    let (row_indices, col_indices, values) = graph.find();
    let kept: Vec<_> = row_indices
        .iter()
        .zip(col_indices.iter())
        .enumerate()
        .filter_map(|(entry, (&old_row, &old_col))| {
            remap[old_row]
                .zip(remap[old_col])
                .map(|(new_row, new_col)| (new_row, new_col, values[entry]))
        })
        .collect();

    // Clear the clone before writing the renumbered entries into it.
    let mut subgraph = graph.clone();
    let (stale_rows, stale_cols, _) = subgraph.find();
    for (&row, &col) in stale_rows.iter().zip(stale_cols.iter()) {
        let _ = subgraph.set(row, col, T::sparse_zero());
    }
    subgraph.eliminate_zeros();

    for (row, col, value) in kept {
        if row < original_n && col < original_n {
            let _ = subgraph.set(row, col, value);
        }
    }

    Ok((subgraph, vertex_indices))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;

    /// Builds an `n x n` graph storing weight `1.0` at every listed `(row, col)` entry.
    fn unit_graph(edges: &[(usize, usize)], n: usize) -> CsrArray<f64> {
        let rows: Vec<usize> = edges.iter().map(|&(row, _)| row).collect();
        let cols: Vec<usize> = edges.iter().map(|&(_, col)| col).collect();
        let data = vec![1.0; edges.len()];
        CsrArray::from_triplets(&rows, &cols, &data, (n, n), false).expect("Operation failed")
    }

    /// Two separate symmetric edges: 0-1 and 2-3.
    fn create_disconnected_graph() -> CsrArray<f64> {
        unit_graph(&[(0, 1), (1, 0), (2, 3), (3, 2)], 4)
    }

    /// Directed cycle 0 -> 1 -> 2 -> 0 plus an untouched vertex 3.
    fn create_strongly_connected_graph() -> CsrArray<f64> {
        unit_graph(&[(0, 1), (1, 2), (2, 0)], 4)
    }

    /// Symmetric path 0-1-2 on three vertices.
    fn path_graph() -> CsrArray<f64> {
        unit_graph(&[(0, 1), (1, 0), (1, 2), (2, 1)], 3)
    }

    #[test]
    fn test_undirected_connected_components() {
        let graph = create_disconnected_graph();
        let (count, labels) =
            undirected_connected_components(&graph, true).expect("Operation failed");
        assert_eq!(count, 2);

        let labels = labels.expect("Operation failed");
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[2], labels[3]);
        assert_ne!(labels[0], labels[2]);
    }

    #[test]
    fn test_connected_components_api() {
        let graph = create_disconnected_graph();

        let (undirected_count, _) =
            connected_components(&graph, false, "weak", false).expect("Operation failed");
        assert_eq!(undirected_count, 2);

        let (directed_count, _) =
            connected_components(&graph, true, "weak", false).expect("Operation failed");
        assert_eq!(directed_count, 2);
    }

    #[test]
    fn test_strongly_connected_components() {
        let graph = create_strongly_connected_graph();
        let (count, labels) =
            strongly_connected_components(&graph, true).expect("Operation failed");
        assert_eq!(count, 2);

        let labels = labels.expect("Operation failed");
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[1], labels[2]);
        assert_ne!(labels[0], labels[3]);
    }

    #[test]
    fn test_is_connected() {
        let disconnected = create_disconnected_graph();
        assert!(!is_connected(&disconnected, false).expect("Operation failed"));

        let connected = path_graph();
        assert!(is_connected(&connected, false).expect("Operation failed"));
    }

    #[test]
    fn test_largest_component() {
        // Path 0-1-2, pair 3-4, and vertex 5 on its own.
        let graph = unit_graph(&[(0, 1), (1, 0), (1, 2), (2, 1), (3, 4), (4, 3)], 6);
        let (size, indices) = largest_component(&graph, false, "weak").expect("Operation failed");
        assert_eq!(size, 3);
        assert_eq!(indices.len(), 3);
        for vertex in 0..3 {
            assert!(indices.contains(&vertex));
        }
    }

    #[test]
    fn test_single_component() {
        let graph = path_graph();

        let (count, _) =
            connected_components(&graph, false, "weak", false).expect("Operation failed");
        assert_eq!(count, 1);

        let (size, indices) = largest_component(&graph, false, "weak").expect("Operation failed");
        assert_eq!(size, 3);
        assert_eq!(indices, vec![0, 1, 2]);
    }

    #[test]
    fn test_isolated_vertices() {
        // One edge 0-1; vertices 2 and 3 have no edges at all.
        let graph = unit_graph(&[(0, 1), (1, 0)], 4);
        let (count, labels) =
            connected_components(&graph, false, "weak", true).expect("Operation failed");
        assert_eq!(count, 3);

        let labels = labels.expect("Operation failed");
        assert_eq!(labels[0], labels[1]);
        assert_ne!(labels[0], labels[2]);
        assert_ne!(labels[0], labels[3]);
        assert_ne!(labels[2], labels[3]);
    }
}