use crate::errors::GraphError;
use crate::graph::traits::{GraphBase, GraphQuery};
use crate::graph::Graph;
use crate::node::NodeIndex;
use crate::GraphResult;
use std::collections::VecDeque;
/// Depth-first traversal of the nodes reachable from `start`.
///
/// Uses an explicit stack. Each popped node that has not been seen yet is
/// marked, handed to `visitor`, and then its unseen neighbors are pushed.
/// Because the stack is LIFO, the neighbor yielded last by `graph.neighbors`
/// is explored first. Returning `false` from `visitor` ends the traversal
/// immediately.
///
/// # Panics
/// Panics if `start`, or any node reached from it, has an index
/// `>= graph.node_count()`.
pub fn dfs<T, F>(graph: &Graph<T, impl Clone>, start: NodeIndex, mut visitor: F)
where
    F: FnMut(NodeIndex) -> bool,
{
    let mut seen = vec![false; graph.node_count()];
    let mut pending = vec![start];

    while let Some(current) = pending.pop() {
        // A node can be pushed several times before it is first popped;
        // only the first pop counts as a visit.
        let already_seen = std::mem::replace(&mut seen[current.index()], true);
        if already_seen {
            continue;
        }
        if !visitor(current) {
            return;
        }
        for next in graph.neighbors(current) {
            if !seen[next.index()] {
                pending.push(next);
            }
        }
    }
}
/// Breadth-first traversal of the nodes reachable from `start`.
///
/// `visitor` receives each node together with its distance, in edges, from
/// `start`. `start` itself is at distance 0. Nodes are marked as discovered
/// when they are enqueued, so each one is visited at most once. Returning
/// `false` from `visitor` stops the traversal immediately.
///
/// # Panics
/// Panics if `start`, or any node reached from it, has an index
/// `>= graph.node_count()`.
pub fn bfs<T, F>(graph: &Graph<T, impl Clone>, start: NodeIndex, mut visitor: F)
where
    F: FnMut(NodeIndex, usize) -> bool,
{
    let mut discovered = vec![false; graph.node_count()];
    discovered[start.index()] = true;
    let mut frontier: VecDeque<(NodeIndex, usize)> = VecDeque::from(vec![(start, 0)]);

    while let Some((current, level)) = frontier.pop_front() {
        let keep_going = visitor(current, level);
        if !keep_going {
            return;
        }
        let next_level = level + 1;
        for next in graph.neighbors(current) {
            let slot = &mut discovered[next.index()];
            if !*slot {
                *slot = true;
                frontier.push_back((next, next_level));
            }
        }
    }
}
/// Orders the nodes of `graph` topologically using Kahn's algorithm.
///
/// Nodes with no incoming edges are seeded first, in the order `graph.nodes()`
/// yields them. Each node taken from the queue then decrements the in-degree
/// of its successors and enqueues any whose in-degree reaches zero.
///
/// # Errors
/// Returns [`GraphError::GraphHasCycle`] when fewer than `graph.node_count()`
/// nodes could be ordered, i.e. when the graph contains a directed cycle.
///
/// # Panics
/// Panics if any node (or edge endpoint) has an index `>= graph.node_count()`.
pub fn topological_sort<T>(graph: &Graph<T, impl Clone>) -> GraphResult<Vec<NodeIndex>> {
    let node_count = graph.node_count();
    let node_indices: Vec<NodeIndex> = graph.nodes().map(|node| node.index()).collect();

    // Count incoming edges for every node.
    let mut in_degree = vec![0usize; node_count];
    for &node in &node_indices {
        for neighbor in graph.neighbors(node) {
            in_degree[neighbor.index()] += 1;
        }
    }

    // Seed the queue with every source node (in-degree zero).
    let mut queue: VecDeque<NodeIndex> = node_indices
        .iter()
        .copied()
        .filter(|node| in_degree[node.index()] == 0)
        .collect();

    let mut result = Vec::with_capacity(node_count);
    while let Some(node) = queue.pop_front() {
        result.push(node);
        for neighbor in graph.neighbors(node) {
            // Every decrement pairs with an increment from the counting pass,
            // so this cannot underflow.
            let degree = &mut in_degree[neighbor.index()];
            *degree -= 1;
            if *degree == 0 {
                queue.push_back(neighbor);
            }
        }
    }

    // Nodes on a cycle never reach in-degree zero, so they are never emitted.
    if result.len() == node_count {
        Ok(result)
    } else {
        Err(GraphError::GraphHasCycle)
    }
}
/// Enumerates every simple path (no repeated node) from `source` to `target`.
///
/// Each returned path starts with `source` and ends with `target`. If
/// `source == target`, the single path `[source]` is returned. Paths are found
/// by backtracking depth-first search, so the work can grow exponentially with
/// graph size.
///
/// # Panics
/// Panics if `source`, or any node reached from it, has an index
/// `>= graph.node_count()`.
pub fn all_paths<T>(
    graph: &Graph<T, impl Clone>,
    source: NodeIndex,
    target: NodeIndex,
) -> Vec<Vec<NodeIndex>> {
    /// Extends `trail` through every neighbor of `at` not already on it,
    /// recording a copy of `trail` whenever `goal` is reached.
    fn walk<T>(
        graph: &Graph<T, impl Clone>,
        at: NodeIndex,
        goal: NodeIndex,
        trail: &mut Vec<NodeIndex>,
        on_trail: &mut [bool],
        found: &mut Vec<Vec<NodeIndex>>,
    ) {
        if at == goal {
            found.push(trail.to_vec());
            return;
        }
        for next in graph.neighbors(at) {
            let slot = next.index();
            if on_trail[slot] {
                continue;
            }
            // Step onto `next`, explore, then undo so sibling branches may reuse it.
            on_trail[slot] = true;
            trail.push(next);
            walk(graph, next, goal, trail, on_trail, found);
            trail.pop();
            on_trail[slot] = false;
        }
    }

    let mut on_trail = vec![false; graph.node_count()];
    on_trail[source.index()] = true;
    let mut trail = vec![source];
    let mut found = Vec::new();
    walk(graph, source, target, &mut trail, &mut on_trail, &mut found);
    found
}
/// Computes the strongly connected components of `graph` with Tarjan's algorithm.
///
/// Components are emitted in the order the algorithm completes them. Nodes
/// within a component appear in the order they were popped off Tarjan's stack.
///
/// The depth-first search is driven by an explicit stack of frames instead of
/// recursion, so deep graphs (e.g. long chains) cannot overflow the thread's
/// call stack. Each frame mirrors one call of the classic recursive
/// `strongconnect`, so the output is the same as the recursive formulation.
///
/// Slots are positions in `graph.nodes()`. A neighbor's `index()` is used
/// directly as its slot, and neighbors whose index is out of range are skipped.
pub fn tarjan_scc<T: Clone>(graph: &Graph<T, impl Clone>) -> Vec<Vec<NodeIndex>> {
    let index_to_node: Vec<NodeIndex> = graph.nodes().map(|node| node.index()).collect();
    let n = index_to_node.len();
    if n == 0 {
        return Vec::new();
    }

    // Snapshot a node's successors so a suspended frame can resume where it left off.
    let successors = |slot: usize| -> std::vec::IntoIter<NodeIndex> {
        let mut out: Vec<NodeIndex> = Vec::new();
        out.extend(graph.neighbors(index_to_node[slot]));
        out.into_iter()
    };

    let mut index_counter = 0usize;
    let mut stack: Vec<usize> = Vec::new();
    let mut on_stack = vec![false; n];
    let mut indices: Vec<Option<usize>> = vec![None; n];
    let mut lowlinks = vec![0usize; n];
    let mut sccs: Vec<Vec<NodeIndex>> = Vec::new();
    // Each frame is (slot being visited, its successors not yet examined).
    let mut frames: Vec<(usize, std::vec::IntoIter<NodeIndex>)> = Vec::new();

    for root in 0..n {
        if indices[root].is_some() {
            continue;
        }
        let mut descend_into = Some(root);
        loop {
            if let Some(v) = descend_into.take() {
                // Entering `v`: the start of a recursive `strongconnect(v)`.
                indices[v] = Some(index_counter);
                lowlinks[v] = index_counter;
                index_counter += 1;
                stack.push(v);
                on_stack[v] = true;
                frames.push((v, successors(v)));
            }

            let (v, pending) = match frames.last_mut() {
                Some(frame) => frame,
                None => break,
            };
            let v = *v;

            for neighbor in pending.by_ref() {
                let w = neighbor.index();
                if w >= n {
                    continue;
                }
                match indices[w] {
                    None => {
                        // Suspend `v` and visit `w`, as a recursive call would.
                        descend_into = Some(w);
                        break;
                    }
                    Some(idx_w) if on_stack[w] => {
                        lowlinks[v] = lowlinks[v].min(idx_w);
                    }
                    _ => {}
                }
            }
            if descend_into.is_some() {
                continue;
            }

            // Every successor of `v` has been examined: the equivalent of returning.
            frames.pop();
            if Some(lowlinks[v]) == indices[v] {
                let mut scc = Vec::new();
                loop {
                    let w = stack
                        .pop()
                        .expect("v stays on the Tarjan stack until its component is popped");
                    on_stack[w] = false;
                    scc.push(index_to_node[w]);
                    if w == v {
                        break;
                    }
                }
                sccs.push(scc);
            }
            // Propagate the child's lowlink to its caller, as after a recursive return.
            if let Some((parent, _)) = frames.last() {
                let parent = *parent;
                lowlinks[parent] = lowlinks[parent].min(lowlinks[v]);
            }
        }
    }
    sccs
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::builders::GraphBuilder;
    use std::collections::HashMap;

    /// Diamond DAG: A -> B, A -> C, B -> D, C -> D (node i is builder position i).
    fn diamond() -> Graph<&'static str, f64> {
        GraphBuilder::directed()
            .with_nodes(vec!["A", "B", "C", "D"])
            .with_edges(vec![(0, 1, 1.0), (0, 2, 1.0), (1, 3, 1.0), (2, 3, 1.0)])
            .build()
            .unwrap()
    }

    #[test]
    fn test_dfs() {
        let graph = diamond();
        let start = graph.nodes().next().unwrap().index();
        let mut visited = Vec::new();
        dfs(&graph, start, |node| {
            visited.push(node.index());
            true
        });
        assert_eq!(visited.len(), 4);
        assert_eq!(visited[0], start.index());
    }

    #[test]
    fn test_dfs_stops_when_visitor_returns_false() {
        let graph = diamond();
        let start = graph.nodes().next().unwrap().index();
        let mut visited = Vec::new();
        dfs(&graph, start, |node| {
            visited.push(node.index());
            false
        });
        assert_eq!(visited, vec![start.index()]);
    }

    #[test]
    fn test_bfs() {
        let graph = diamond();
        let start = graph.nodes().next().unwrap().index();
        let mut visited = Vec::new();
        bfs(&graph, start, |node, _depth| {
            visited.push(node.index());
            true
        });
        assert_eq!(visited.len(), 4);
    }

    #[test]
    fn test_bfs_reports_depths() {
        let graph = diamond();
        let start = graph.nodes().next().unwrap().index();
        let mut depths: HashMap<usize, usize> = HashMap::new();
        bfs(&graph, start, |node, depth| {
            depths.insert(node.index(), depth);
            true
        });
        assert_eq!(depths[&0], 0);
        assert_eq!(depths[&1], 1);
        assert_eq!(depths[&2], 1);
        assert_eq!(depths[&3], 2);
    }

    #[test]
    fn test_topological_sort() {
        let graph = diamond();
        let result = topological_sort(&graph).unwrap();
        assert_eq!(result.len(), 4);
        // Position of each node index within the returned order.
        let pos: HashMap<usize, usize> = result
            .iter()
            .enumerate()
            .map(|(i, node)| (node.index(), i))
            .collect();
        assert!(pos[&0] < pos[&1]);
        assert!(pos[&0] < pos[&2]);
        assert!(pos[&1] < pos[&3]);
        assert!(pos[&2] < pos[&3]);
    }

    #[test]
    fn test_topological_sort_detects_cycle() {
        let graph = GraphBuilder::directed()
            .with_nodes(vec!["A", "B"])
            .with_edges(vec![(0, 1, 1.0), (1, 0, 1.0)])
            .build()
            .unwrap();
        assert!(matches!(
            topological_sort(&graph),
            Err(GraphError::GraphHasCycle)
        ));
    }

    #[test]
    fn test_all_paths_diamond() {
        let graph = diamond();
        let nodes: Vec<NodeIndex> = graph.nodes().map(|n| n.index()).collect();
        let paths = all_paths(&graph, nodes[0], nodes[3]);
        let mut as_indices: Vec<Vec<usize>> = paths
            .iter()
            .map(|path| path.iter().map(|n| n.index()).collect())
            .collect();
        as_indices.sort();
        assert_eq!(as_indices, vec![vec![0, 1, 3], vec![0, 2, 3]]);
    }

    #[test]
    fn test_all_paths_source_equals_target() {
        let graph = diamond();
        let start = graph.nodes().next().unwrap().index();
        let paths = all_paths(&graph, start, start);
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0].len(), 1);
    }

    #[test]
    fn test_tarjan_scc() {
        let graph = GraphBuilder::directed()
            .with_nodes(vec![0, 1, 2, 3, 4])
            .with_edges(vec![
                (0, 1, 1.0),
                (1, 2, 1.0),
                (2, 0, 1.0),
                (2, 3, 1.0),
                (3, 4, 1.0),
                (4, 3, 1.0),
            ])
            .build()
            .unwrap();
        let sccs = tarjan_scc(&graph);
        assert_eq!(sccs.len(), 2);
        let sizes: Vec<_> = sccs.iter().map(|scc| scc.len()).collect();
        assert!(sizes.contains(&3));
        assert!(sizes.contains(&2));
    }

    #[test]
    fn test_tarjan_scc_dag_has_singleton_components() {
        let graph = diamond();
        let sccs = tarjan_scc(&graph);
        assert_eq!(sccs.len(), 4);
        assert!(sccs.iter().all(|scc| scc.len() == 1));
    }

    #[test]
    fn test_tarjan_scc_single_node() {
        let graph: Graph<i32, f64> = GraphBuilder::directed()
            .with_nodes(vec![1])
            .build()
            .unwrap();
        let sccs = tarjan_scc(&graph);
        assert_eq!(sccs.len(), 1);
        assert_eq!(sccs[0].len(), 1);
    }

    #[test]
    fn test_tarjan_scc_empty_graph() {
        let graph = GraphBuilder::<i32, f64>::directed().build().unwrap();
        let sccs = tarjan_scc(&graph);
        assert!(sccs.is_empty());
    }
}