use crate::graph::graph::Graph;
use std::collections::{HashSet, VecDeque};
use std::fmt::Debug;
use std::hash::Hash;
/// Errors produced by the connectivity algorithms in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConnectivityError {
    /// A node id was looked up, or referenced from an adjacency list, but no node
    /// with that id exists in the graph.
    InvalidNode(usize),
    /// Graph transposition failed; the payload is a human-readable description.
    TransposeError(String),
}

impl std::fmt::Display for ConnectivityError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidNode(id) => {
                write!(f, "Invalid node reference: node ID {} not found", id)
            }
            Self::TransposeError(msg) => write!(f, "Graph transposition failed: {}", msg),
        }
    }
}

impl std::error::Error for ConnectivityError {}
/// Connectivity queries over a graph whose nodes are identified by `usize` ids.
///
/// The type parameters `W`, `N` and `E` are not used by any method here; presumably
/// they mirror the implementing graph's weight / node-data / edge-data parameters.
pub trait Connectivity<W, N, E> {
    /// Returns the weakly connected components (reachability with edge direction ignored).
    fn find_weakly_connected_components(&self) -> Result<Vec<Vec<usize>>, ConnectivityError>;

    /// Returns the strongly connected components (mutual reachability along edge direction).
    fn find_strongly_connected_components(&self) -> Result<Vec<Vec<usize>>, ConnectivityError>;

    /// Returns strongly connected components for directed graphs and weakly connected
    /// components otherwise.
    ///
    /// # Errors
    /// Propagates any error from [`Connectivity::is_directed`] or from the selected
    /// component finder. The directedness error is not masked by guessing "undirected".
    fn find_connected_components(&self) -> Result<Vec<Vec<usize>>, ConnectivityError> {
        if self.is_directed()? {
            self.find_strongly_connected_components()
        } else {
            self.find_weakly_connected_components()
        }
    }

    /// Returns `true` if the graph forms a single weakly connected component.
    fn is_weakly_connected(&self) -> Result<bool, ConnectivityError>;

    /// Returns `true` if every node can reach every other node along edge direction.
    fn is_strongly_connected(&self) -> Result<bool, ConnectivityError>;

    /// Returns strong connectivity for directed graphs and weak connectivity otherwise.
    ///
    /// # Errors
    /// Propagates any error from [`Connectivity::is_directed`] or from the selected check.
    fn is_connected(&self) -> Result<bool, ConnectivityError> {
        if self.is_directed()? {
            self.is_strongly_connected()
        } else {
            self.is_weakly_connected()
        }
    }

    /// Reports whether the graph's edges are directed.
    fn is_directed(&self) -> Result<bool, ConnectivityError>;
}
impl<W, N, E> Connectivity<W, N, E> for Graph<W, N, E>
where
    W: Copy + Default + PartialEq + Debug,
    N: Clone + Eq + Hash + Debug,
    E: Clone + Default + Debug,
{
    /// Groups nodes into weakly connected components with breadth-first search.
    ///
    /// For directed graphs both the stored adjacency lists and the predecessor lists
    /// (built from `self.edges`) are followed, so edge direction is ignored. For undirected
    /// graphs only the stored adjacency lists are followed.
    ///
    /// Components are emitted in the node-iteration order of their first member. Within a
    /// component, nodes appear in BFS discovery order.
    ///
    /// # Errors
    /// Returns [`ConnectivityError::InvalidNode`] if a reached id has no entry in
    /// `self.nodes`.
    fn find_weakly_connected_components(&self) -> Result<Vec<Vec<usize>>, ConnectivityError> {
        let mut visited = HashSet::new();
        let mut components = Vec::new();
        // Incoming edges are only needed when direction must be ignored explicitly.
        let predecessors = if self.directed {
            self.predecessors()
        } else {
            std::collections::HashMap::new()
        };
        for node in self.nodes.iter().map(|(id, _)| id) {
            // `insert` returns false when `node` already belongs to an earlier component.
            if !visited.insert(node) {
                continue;
            }
            let mut component = Vec::new();
            let mut queue = VecDeque::new();
            queue.push_back(node);
            while let Some(current) = queue.pop_front() {
                component.push(current);
                let node_data = self
                    .nodes
                    .get(current)
                    .ok_or(ConnectivityError::InvalidNode(current))?;
                let outgoing = node_data.neighbors.iter().map(|(n, _)| *n);
                let incoming = predecessors.get(&current).into_iter().flatten().copied();
                for neighbor in outgoing.chain(incoming) {
                    if visited.contains(&neighbor) {
                        continue;
                    }
                    if !self.nodes.contains(neighbor) {
                        return Err(ConnectivityError::InvalidNode(neighbor));
                    }
                    visited.insert(neighbor);
                    queue.push_back(neighbor);
                }
            }
            components.push(component);
        }
        Ok(components)
    }

    /// Finds strongly connected components with Kosaraju's algorithm.
    ///
    /// The algorithm first runs a forward DFS to record finish order. It then runs a DFS over
    /// predecessor lists (the transposed graph) in reverse finish order. Each component is
    /// returned sorted ascending.
    ///
    /// NOTE(review): the reverse pass uses predecessors derived from `self.edges`. Confirm how
    /// undirected graphs store edges before relying on this for undirected graphs.
    ///
    /// # Errors
    /// Returns [`ConnectivityError::InvalidNode`] if the forward pass reaches an id with no
    /// entry in `self.nodes`.
    fn find_strongly_connected_components(&self) -> Result<Vec<Vec<usize>>, ConnectivityError> {
        let mut visited = HashSet::new();
        let mut order = Vec::with_capacity(self.nodes.len());
        for node in self.nodes.iter().map(|(id, _)| id) {
            if !visited.contains(&node) {
                self.dfs_order(node, &mut visited, &mut order)?;
            }
        }
        let predecessors = self.predecessors();
        let mut visited = HashSet::new();
        let mut components = Vec::new();
        for &node in order.iter().rev() {
            if !visited.contains(&node) {
                let mut component = Vec::new();
                self.dfs_collect_reverse(node, &predecessors, &mut visited, &mut component);
                component.sort_unstable();
                components.push(component);
            }
        }
        Ok(components)
    }

    /// Returns `true` when the graph has at most one weakly connected component.
    ///
    /// The empty graph is treated as (vacuously) connected, matching
    /// `is_strongly_connected`. Without this, `is_connected` would answer differently
    /// for an empty directed graph and an empty undirected graph.
    fn is_weakly_connected(&self) -> Result<bool, ConnectivityError> {
        let components = self.find_weakly_connected_components()?;
        Ok(components.len() <= 1)
    }

    /// Returns `true` if every node reaches, and is reached from, an arbitrary start node.
    /// The empty graph is treated as connected.
    fn is_strongly_connected(&self) -> Result<bool, ConnectivityError> {
        match self.nodes.iter().map(|(id, _)| id).next() {
            None => Ok(true),
            Some(start) => self.strong_connectivity_check(start),
        }
    }

    /// Reports the graph's `directed` flag; never fails.
    fn is_directed(&self) -> Result<bool, ConnectivityError> {
        Ok(self.directed)
    }
}
impl<W, N, E> Graph<W, N, E>
where
    W: Copy + Default + PartialEq + Debug,
    N: Clone + Eq + Hash + Debug,
    E: Clone + Default + Debug,
{
    /// Depth-first search from `node` along adjacency lists. Each reached node is appended
    /// to `order` after all of its unvisited descendants (finish order).
    ///
    /// An explicit stack of neighbor iterators replaces recursion, so long chains cannot
    /// overflow the call stack. Nodes are visited and finished in exactly the order a
    /// recursive DFS walking each adjacency list front to back would produce.
    ///
    /// # Errors
    /// Returns [`ConnectivityError::InvalidNode`] if `node`, or any newly reached id, has
    /// no entry in `self.nodes`. That id has already been added to `visited` at that point.
    fn dfs_order(
        &self,
        node: usize,
        visited: &mut HashSet<usize>,
        order: &mut Vec<usize>,
    ) -> Result<(), ConnectivityError> {
        visited.insert(node);
        let root_neighbors = self
            .nodes
            .get(node)
            .ok_or(ConnectivityError::InvalidNode(node))?
            .neighbors
            .iter();
        // Each frame holds a node and the iterator over its not-yet-examined neighbors.
        let mut stack = vec![(node, root_neighbors)];
        while let Some(frame) = stack.last_mut() {
            let current = frame.0;
            let step = frame.1.next();
            match step {
                Some((neighbor, _)) => {
                    let neighbor = *neighbor;
                    if visited.insert(neighbor) {
                        let neighbors = self
                            .nodes
                            .get(neighbor)
                            .ok_or(ConnectivityError::InvalidNode(neighbor))?
                            .neighbors
                            .iter();
                        stack.push((neighbor, neighbors));
                    }
                }
                None => {
                    // Every neighbor of `current` has been explored: it is finished.
                    order.push(current);
                    stack.pop();
                }
            }
        }
        Ok(())
    }

    /// Depth-first search from `node` along adjacency lists. Each newly visited node is
    /// appended to `component` in pre-order.
    ///
    /// The search is iterative (explicit iterator stack) to avoid stack overflow on deep
    /// graphs. It preserves the visiting order of the equivalent recursive search.
    ///
    /// # Errors
    /// Returns [`ConnectivityError::InvalidNode`] if `node`, or any newly reached id, has
    /// no entry in `self.nodes`.
    fn dfs_collect(
        &self,
        node: usize,
        visited: &mut HashSet<usize>,
        component: &mut Vec<usize>,
    ) -> Result<(), ConnectivityError> {
        visited.insert(node);
        component.push(node);
        let root_neighbors = self
            .nodes
            .get(node)
            .ok_or(ConnectivityError::InvalidNode(node))?
            .neighbors
            .iter();
        let mut stack = vec![root_neighbors];
        while let Some(neighbors) = stack.last_mut() {
            let step = neighbors.next();
            match step {
                Some((neighbor, _)) => {
                    let neighbor = *neighbor;
                    if visited.insert(neighbor) {
                        component.push(neighbor);
                        let next_neighbors = self
                            .nodes
                            .get(neighbor)
                            .ok_or(ConnectivityError::InvalidNode(neighbor))?
                            .neighbors
                            .iter();
                        stack.push(next_neighbors);
                    }
                }
                None => {
                    stack.pop();
                }
            }
        }
        Ok(())
    }

    /// Checks that every node is reachable from `start` along adjacency lists, and that
    /// `start` is reachable from every node (DFS over predecessor lists).
    fn strong_connectivity_check(&self, start: usize) -> Result<bool, ConnectivityError> {
        let mut forward_visited = HashSet::new();
        self.dfs_collect(start, &mut forward_visited, &mut Vec::new())?;
        if forward_visited.len() != self.nodes.len() {
            return Ok(false);
        }
        let predecessors = self.predecessors();
        let mut backward_visited = HashSet::new();
        self.dfs_collect_reverse(start, &predecessors, &mut backward_visited, &mut Vec::new());
        Ok(backward_visited.len() == self.nodes.len())
    }

    /// Builds a map from each node id to the sources of edges pointing at it, using
    /// `self.edges`.
    ///
    /// Every id in `self.nodes` gets an entry, possibly empty. A source appears once per
    /// edge, so parallel edges yield repeated entries.
    fn predecessors(&self) -> std::collections::HashMap<usize, Vec<usize>> {
        let mut predecessors: std::collections::HashMap<usize, Vec<usize>> =
            std::collections::HashMap::with_capacity(self.nodes.len());
        for (id, _) in self.nodes.iter() {
            predecessors.entry(id).or_default();
        }
        for (_, edge) in self.edges.iter() {
            predecessors.entry(edge.to).or_default().push(edge.from);
        }
        predecessors
    }

    /// Iterative DFS over `predecessors` (the transposed graph) starting at `node`. Each
    /// newly visited node is appended to `component`.
    ///
    /// Ids found in `predecessors` are not validated against `self.nodes`.
    fn dfs_collect_reverse(
        &self,
        node: usize,
        predecessors: &std::collections::HashMap<usize, Vec<usize>>,
        visited: &mut HashSet<usize>,
        component: &mut Vec<usize>,
    ) {
        let mut stack = vec![node];
        while let Some(current) = stack.pop() {
            // A node can be pushed more than once before it is popped; skip repeats.
            if !visited.insert(current) {
                continue;
            }
            component.push(current);
            if let Some(preds) = predecessors.get(&current) {
                for &pred in preds {
                    if !visited.contains(&pred) {
                        stack.push(pred);
                    }
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    type TestGraph = Graph<u32, (), ()>;

    /// Creates a directed graph with `count` unit-data nodes. Returns the graph together
    /// with the node ids, in insertion order.
    fn directed_graph(count: usize) -> (TestGraph, Vec<usize>) {
        let mut graph = TestGraph::new(true);
        let ids: Vec<usize> = (0..count).map(|_| graph.add_node(())).collect();
        (graph, ids)
    }

    /// Adds a weight-1, unit-data edge for each `(from, to)` pair, in the given order.
    fn connect(graph: &mut TestGraph, links: &[(usize, usize)]) {
        for &(from, to) in links {
            graph.add_edge(from, to, 1, ()).unwrap();
        }
    }

    #[test]
    fn test_strongly_connected() {
        let (mut graph, ids) = directed_graph(3);
        connect(&mut graph, &[(ids[0], ids[1]), (ids[1], ids[2]), (ids[2], ids[0])]);
        assert!(graph.is_strongly_connected().unwrap());
        let components = graph.find_strongly_connected_components().unwrap();
        assert_eq!(components.len(), 1);
    }

    #[test]
    fn test_weak_connectivity() {
        let (mut graph, ids) = directed_graph(2);
        connect(&mut graph, &[(ids[0], ids[1])]);
        assert!(!graph.is_strongly_connected().unwrap());
        assert!(graph.is_weakly_connected().unwrap());
    }

    #[test]
    fn test_scc_with_sparse_ids() {
        let (mut graph, ids) = directed_graph(4);
        let (first, removed, second, third) = (ids[0], ids[1], ids[2], ids[3]);
        // Removing a middle node leaves a gap in the id space.
        graph.remove_node(removed).unwrap();
        connect(&mut graph, &[(first, second), (second, third), (third, first)]);
        let components = graph.find_strongly_connected_components().unwrap();
        assert_eq!(components.len(), 1);
        assert_eq!(components[0], vec![first, second, third]);
        assert!(graph.is_strongly_connected().unwrap());
    }
}