use crate::parallel::partitioner::Partition;
use crate::node::NodeIndex;
use crate::vgi::VirtualGraph;
use std::collections::HashMap;
use std::time::Instant;
/// Options that control how a [`DistributedDFS`] traversal is carried out.
///
/// Build one with [`DFSConfig::new`] and the `with_*` methods, then check it
/// with [`DFSConfig::validate`] if the values come from an untrusted source.
#[derive(Debug, Clone)]
pub struct DFSConfig {
/// Flag requesting that paths be recorded.
///
/// NOTE(review): the traversal code in this file stores predecessors
/// unconditionally and never reads this flag.
pub record_path: bool,
/// Depth at which the traversal stops expanding nodes; `None` means no
/// configured limit. `Some(0)` is rejected by [`DFSConfig::validate`].
pub max_depth: Option<usize>,
/// `true` selects the explicit-stack traversal, `false` the recursive one.
pub iterative: bool,
}
/// Reasons a [`DFSConfig`] is rejected by [`DFSConfig::validate`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DFSConfigError {
/// `max_depth` was set to `Some(0)`.
ZeroMaxDepth,
/// Recursive mode was selected with a `max_depth` larger than
/// [`MAX_SAFE_RECURSION_DEPTH`].
RecursionDepthExceeded,
}
/// Largest `max_depth` accepted by [`DFSConfig::validate`] in recursive mode,
/// and the depth beyond which the recursive traversal refuses to descend.
pub const MAX_SAFE_RECURSION_DEPTH: usize = 1000;
impl std::fmt::Display for DFSConfigError {
    /// Writes a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ZeroMaxDepth => f.write_str("max_depth must be greater than 0"),
            Self::RecursionDepthExceeded => {
                // The limit is interpolated so the message always matches the constant.
                let limit = MAX_SAFE_RECURSION_DEPTH;
                write!(
                    f,
                    "recursion depth exceeded safe limit ({}), use iterative mode instead",
                    limit
                )
            }
        }
    }
}
// Marker impl: lets `DFSConfigError` be used wherever `std::error::Error` is
// expected (e.g. boxed as `Box<dyn Error>`); no methods are overridden.
impl std::error::Error for DFSConfigError {}
impl Default for DFSConfig {
fn default() -> Self {
Self {
record_path: false,
max_depth: None,
iterative: true,
}
}
}
impl DFSConfig {
    /// Returns the default configuration; equivalent to [`DFSConfig::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Checks that the configured values are usable.
    ///
    /// # Errors
    ///
    /// * [`DFSConfigError::ZeroMaxDepth`] when `max_depth` is `Some(0)`
    ///   (checked first, regardless of mode).
    /// * [`DFSConfigError::RecursionDepthExceeded`] when recursive mode is
    ///   selected and `max_depth` exceeds [`MAX_SAFE_RECURSION_DEPTH`].
    pub fn validate(&self) -> Result<(), DFSConfigError> {
        match self.max_depth {
            Some(0) => Err(DFSConfigError::ZeroMaxDepth),
            Some(limit) if !self.iterative && limit > MAX_SAFE_RECURSION_DEPTH => {
                Err(DFSConfigError::RecursionDepthExceeded)
            }
            _ => Ok(()),
        }
    }

    /// Builds the default configuration and validates it before returning it.
    ///
    /// # Errors
    ///
    /// Propagates whatever [`DFSConfig::validate`] reports.
    pub fn try_new() -> Result<Self, DFSConfigError> {
        let candidate = Self::default();
        candidate.validate().map(|()| candidate)
    }

    /// Sets the `record_path` flag and returns the updated configuration.
    pub fn with_record_path(self, record: bool) -> Self {
        Self {
            record_path: record,
            ..self
        }
    }

    /// Sets the depth limit to `Some(max_depth)` and returns the updated
    /// configuration. No validation happens here; see [`DFSConfig::validate`].
    pub fn with_max_depth(self, max_depth: usize) -> Self {
        Self {
            max_depth: Some(max_depth),
            ..self
        }
    }

    /// Chooses between the iterative (`true`) and recursive (`false`)
    /// traversal and returns the updated configuration.
    pub fn with_iterative(self, iterative: bool) -> Self {
        Self { iterative, ..self }
    }
}
/// Outcome of a [`DistributedDFS::compute`] run.
///
/// Per-node data lives in parallel vectors indexed by an internal position;
/// `node_id_to_pos` translates a `NodeIndex::index()` into that position.
/// Use the accessor methods rather than the vectors directly.
#[derive(Debug, Clone)]
pub struct DFSResult {
/// Node the traversal was started from.
pub start_node: NodeIndex,
/// Discovery timestamp per position; `usize::MAX` marks "never discovered".
discovery_time: Vec<usize>,
/// Finish timestamp per position; `usize::MAX` marks "never finished".
finish_time: Vec<usize>,
/// Node from which each position was reached; `None` for the start node
/// and for nodes that were never reached.
predecessors: Vec<Option<NodeIndex>>,
/// Every node taken from the partitions, in position order.
node_ids: Vec<NodeIndex>,
/// Lookup from `NodeIndex::index()` to position; `usize::MAX` marks an
/// index that belongs to no partition.
node_id_to_pos: Vec<usize>,
/// Number of nodes the traversal visited.
pub visited_count: usize,
/// Greatest depth recorded during the traversal (the start node is depth 0).
pub max_depth_reached: usize,
/// Wall-clock duration of the traversal in milliseconds; 0 when the start
/// node was not found in any partition.
pub computation_time_ms: u64,
/// One entry per input partition, in input order.
pub partition_stats: Vec<PartitionDFSStats>,
}
impl DFSResult {
    /// Translates a node into its slot in the per-node vectors, or `None`
    /// when the node is out of range or not part of the traversal's node set.
    #[inline]
    fn get_pos(&self, node: NodeIndex) -> Option<usize> {
        match self.node_id_to_pos.get(node.index()) {
            Some(&slot) if slot != usize::MAX => Some(slot),
            _ => None,
        }
    }

    /// Returns `true` when `node` received a discovery timestamp.
    pub fn is_visited(&self, node: NodeIndex) -> bool {
        self.discovery(node).is_some()
    }

    /// Discovery timestamp of `node`, or `None` if it was never discovered.
    pub fn discovery(&self, node: NodeIndex) -> Option<usize> {
        let slot = self.get_pos(node)?;
        Some(self.discovery_time[slot]).filter(|&stamp| stamp != usize::MAX)
    }

    /// Finish timestamp of `node`, or `None` if none was recorded.
    pub fn finish(&self, node: NodeIndex) -> Option<usize> {
        let slot = self.get_pos(node)?;
        Some(self.finish_time[slot]).filter(|&stamp| stamp != usize::MAX)
    }

    /// Follows predecessor links from `target` back towards the start node
    /// and returns the nodes in start-to-target order.
    ///
    /// Returns an empty vector when `target` was not visited. The walk ends
    /// at the start node or at the first node without a predecessor.
    pub fn reconstruct_path(&self, target: NodeIndex) -> Vec<NodeIndex> {
        if !self.is_visited(target) {
            return Vec::new();
        }

        let mut trail = vec![target];
        let mut cursor = target;
        loop {
            let parent = self
                .get_pos(cursor)
                .and_then(|slot| self.predecessors.get(slot).copied().flatten());
            let Some(parent) = parent else { break };

            trail.push(parent);
            if parent == self.start_node {
                break;
            }
            cursor = parent;
        }

        // Collected target-first; callers expect start-first.
        trail.reverse();
        trail
    }

    /// Returns `true` when `from` is the recorded predecessor of `to`.
    pub fn is_tree_edge(&self, from: NodeIndex, to: NodeIndex) -> bool {
        self.get_pos(to)
            .and_then(|slot| self.predecessors.get(slot))
            .is_some_and(|parent| *parent == Some(from))
    }
}
/// Per-partition summary produced by [`DistributedDFS::compute`].
#[derive(Debug, Clone)]
pub struct PartitionDFSStats {
/// Copied from the partition's `id`.
pub partition_id: usize,
/// Number of the partition's nodes that the traversal visited.
pub visited_count: usize,
/// Length of the partition's `boundary_nodes` collection.
pub boundary_count: usize,
/// Depth statistic for the partition as filled in by
/// [`DistributedDFS::compute`]; 0 when none of its nodes were visited.
pub max_depth: usize,
}
/// Depth-first traversal from a fixed start node, restricted to the nodes
/// listed in the partitions handed to [`DistributedDFS::compute`].
pub struct DistributedDFS {
/// Node the traversal begins at.
start_node: NodeIndex,
/// Traversal options (mode and depth limit).
config: DFSConfig,
}
impl DistributedDFS {
    /// Creates a traversal rooted at `start_node` using [`DFSConfig::default`].
    pub fn new(start_node: NodeIndex) -> Self {
        Self::from_config(start_node, DFSConfig::default())
    }

    /// Creates a traversal rooted at `start_node` with an explicit
    /// configuration. The configuration is stored as given; it is not
    /// validated here.
    pub fn from_config(start_node: NodeIndex, config: DFSConfig) -> Self {
        Self { start_node, config }
    }

    /// Runs the depth-first traversal and returns timestamps, predecessors
    /// and per-partition statistics.
    ///
    /// Only nodes listed in `partitions` take part; neighbours outside that
    /// set are ignored. Discovery and finish timestamps share one counter
    /// that starts at 1. When a depth limit is configured, a node at that
    /// depth is discovered and finished but its neighbours are not explored.
    ///
    /// If the start node is not listed in any partition the result is empty:
    /// no node is visited and every partition reports zero visits.
    pub fn compute<G>(&self, graph: &G, partitions: &[Partition]) -> DFSResult
    where
        G: VirtualGraph<NodeData = (), EdgeData = ()>,
    {
        let start_time = Instant::now();

        let all_nodes: Vec<NodeIndex> = partitions
            .iter()
            .flat_map(|p| p.nodes.iter().copied())
            .collect();

        // Dense lookup from `NodeIndex::index()` to a slot in the per-node
        // vectors; `usize::MAX` marks indices that belong to no partition.
        let max_node_index = all_nodes.iter().map(|n| n.index()).max().unwrap_or(0);
        let mut node_id_to_pos: Vec<usize> = vec![usize::MAX; max_node_index + 1];
        for (pos, node) in all_nodes.iter().enumerate() {
            node_id_to_pos[node.index()] = pos;
        }

        let Some(start_pos) = Self::position_of(&node_id_to_pos, self.start_node) else {
            return DFSResult {
                start_node: self.start_node,
                discovery_time: Vec::new(),
                finish_time: Vec::new(),
                predecessors: Vec::new(),
                node_ids: Vec::new(),
                node_id_to_pos: Vec::new(),
                visited_count: 0,
                max_depth_reached: 0,
                computation_time_ms: 0,
                partition_stats: partitions
                    .iter()
                    .map(|p| PartitionDFSStats {
                        partition_id: p.id,
                        visited_count: 0,
                        boundary_count: p.boundary_nodes.len(),
                        max_depth: 0,
                    })
                    .collect(),
            };
        };

        let mut state = DfsScratch::new(all_nodes.len());
        if self.config.iterative {
            self.iterative_dfs(graph, start_pos, &node_id_to_pos, &mut state);
        } else {
            self.recursive_dfs(
                graph,
                self.start_node,
                start_pos,
                &node_id_to_pos,
                None,
                0,
                &mut state,
            );
        }

        // Saturate instead of silently truncating the u128 millisecond count.
        let computation_time_ms =
            start_time.elapsed().as_millis().min(u128::from(u64::MAX)) as u64;
        let visited_count = state.visited.iter().filter(|&&seen| seen).count();

        let partition_stats: Vec<PartitionDFSStats> = partitions
            .iter()
            .map(|p| {
                let mut visited_in_partition = 0;
                let mut max_depth_in_partition = 0;
                let positions = p
                    .nodes
                    .iter()
                    .filter_map(|&node| Self::position_of(&node_id_to_pos, node));
                for pos in positions {
                    if state.visited[pos] {
                        visited_in_partition += 1;
                        // Depth in the DFS tree, not the discovery timestamp.
                        max_depth_in_partition = max_depth_in_partition.max(state.depths[pos]);
                    }
                }
                PartitionDFSStats {
                    partition_id: p.id,
                    visited_count: visited_in_partition,
                    boundary_count: p.boundary_nodes.len(),
                    max_depth: max_depth_in_partition,
                }
            })
            .collect();

        DFSResult {
            start_node: self.start_node,
            discovery_time: state.discovery_time,
            finish_time: state.finish_time,
            predecessors: state.predecessors,
            node_ids: all_nodes,
            node_id_to_pos,
            visited_count,
            max_depth_reached: state.max_depth_reached,
            computation_time_ms,
            partition_stats,
        }
    }

    /// Maps a node to its slot in the per-node vectors, or `None` when the
    /// node is out of range or not part of any partition.
    fn position_of(node_id_to_pos: &[usize], node: NodeIndex) -> Option<usize> {
        node_id_to_pos
            .get(node.index())
            .copied()
            .filter(|&pos| pos != usize::MAX)
    }

    /// `true` when a configured depth limit forbids expanding a node at `depth`.
    fn depth_limit_hit(&self, depth: usize) -> bool {
        self.config.max_depth.is_some_and(|limit| depth >= limit)
    }

    /// Explicit-stack depth-first traversal starting at `self.start_node`.
    ///
    /// A node is marked visited, and its predecessor recorded, only when it
    /// is popped for discovery. Marking at push time would tie a node to
    /// whichever neighbour queued it first, producing a tree and finish
    /// order that no depth-first search would generate. A node may therefore
    /// sit on the stack several times; stale entries are skipped on pop.
    fn iterative_dfs<G>(
        &self,
        graph: &G,
        start_pos: usize,
        node_id_to_pos: &[usize],
        state: &mut DfsScratch,
    ) where
        G: VirtualGraph<NodeData = (), EdgeData = ()>,
    {
        enum Frame {
            /// Discover `node` (unless someone else got there first).
            Enter {
                node: NodeIndex,
                pos: usize,
                depth: usize,
                pred: Option<NodeIndex>,
            },
            /// All descendants of the node at `pos` are done; stamp its finish time.
            Exit { pos: usize },
        }

        let mut stack = vec![Frame::Enter {
            node: self.start_node,
            pos: start_pos,
            depth: 0,
            pred: None,
        }];

        while let Some(frame) = stack.pop() {
            let (node, pos, depth, pred) = match frame {
                Frame::Exit { pos } => {
                    state.leave(pos);
                    continue;
                }
                Frame::Enter {
                    node,
                    pos,
                    depth,
                    pred,
                } => (node, pos, depth, pred),
            };

            if state.visited[pos] {
                // Already discovered through a different edge.
                continue;
            }
            state.enter(pos, pred, depth);

            if self.depth_limit_hit(depth) {
                state.leave(pos);
                continue;
            }

            // The exit marker sits below the children so it pops after them.
            stack.push(Frame::Exit { pos });
            for neighbor in graph.neighbors(node) {
                if let Some(neighbor_pos) = Self::position_of(node_id_to_pos, neighbor) {
                    if !state.visited[neighbor_pos] {
                        stack.push(Frame::Enter {
                            node: neighbor,
                            pos: neighbor_pos,
                            depth: depth + 1,
                            pred: Some(node),
                        });
                    }
                }
            }
        }
    }

    /// Recursive depth-first traversal of `current` and everything reachable
    /// from it that has not been visited yet.
    ///
    /// Calls deeper than [`MAX_SAFE_RECURSION_DEPTH`] return without
    /// recording anything, so such a node stays unvisited and has neither
    /// timestamps nor a predecessor.
    #[allow(clippy::too_many_arguments)] // traversal context is threaded through explicitly
    fn recursive_dfs<G>(
        &self,
        graph: &G,
        current: NodeIndex,
        current_pos: usize,
        node_id_to_pos: &[usize],
        pred: Option<NodeIndex>,
        depth: usize,
        state: &mut DfsScratch,
    ) where
        G: VirtualGraph<NodeData = (), EdgeData = ()>,
    {
        if depth > MAX_SAFE_RECURSION_DEPTH {
            return;
        }
        state.enter(current_pos, pred, depth);

        if self.depth_limit_hit(depth) {
            state.leave(current_pos);
            return;
        }

        for neighbor in graph.neighbors(current) {
            if let Some(neighbor_pos) = Self::position_of(node_id_to_pos, neighbor) {
                // Re-checked on every iteration: an earlier sibling's subtree
                // may have reached this neighbour already.
                if !state.visited[neighbor_pos] {
                    self.recursive_dfs(
                        graph,
                        neighbor,
                        neighbor_pos,
                        node_id_to_pos,
                        Some(current),
                        depth + 1,
                        state,
                    );
                }
            }
        }
        state.leave(current_pos);
    }

    /// Returns the visited nodes ordered by decreasing finish time.
    ///
    /// Only nodes reached from the start node are included.
    ///
    /// NOTE(review): cycles are not detected, so the result is always `Some`;
    /// the ordering is a topological order only if the traversed graph is
    /// acyclic and the traversal was not cut short by a depth limit.
    pub fn topological_sort<G>(&self, graph: &G, partitions: &[Partition]) -> Option<Vec<NodeIndex>>
    where
        G: VirtualGraph<NodeData = (), EdgeData = ()>,
    {
        let result = self.compute(graph, partitions);
        let mut order: Vec<NodeIndex> = result
            .node_ids
            .iter()
            .copied()
            .filter(|&node| result.is_visited(node))
            .collect();
        order.sort_unstable_by_key(|&node| std::cmp::Reverse(result.finish(node).unwrap_or(0)));
        Some(order)
    }
}

/// Mutable bookkeeping shared by the iterative and recursive traversals.
/// All vectors are indexed by node position.
struct DfsScratch {
    visited: Vec<bool>,
    discovery_time: Vec<usize>,
    finish_time: Vec<usize>,
    predecessors: Vec<Option<NodeIndex>>,
    /// Depth in the DFS tree at which each visited node was discovered.
    depths: Vec<usize>,
    /// Shared discovery/finish clock; the first stamp handed out is 1.
    clock: usize,
    max_depth_reached: usize,
}

impl DfsScratch {
    /// Allocates bookkeeping for `node_count` nodes, all unvisited.
    fn new(node_count: usize) -> Self {
        Self {
            visited: vec![false; node_count],
            discovery_time: vec![usize::MAX; node_count],
            finish_time: vec![usize::MAX; node_count],
            predecessors: vec![None; node_count],
            depths: vec![0; node_count],
            clock: 0,
            max_depth_reached: 0,
        }
    }

    /// Records the discovery of the node at `pos`.
    fn enter(&mut self, pos: usize, pred: Option<NodeIndex>, depth: usize) {
        self.visited[pos] = true;
        self.predecessors[pos] = pred;
        self.depths[pos] = depth;
        self.clock += 1;
        self.discovery_time[pos] = self.clock;
        self.max_depth_reached = self.max_depth_reached.max(depth);
    }

    /// Records that the node at `pos` is finished.
    fn leave(&mut self, pos: usize) {
        self.clock += 1;
        self.finish_time[pos] = self.clock;
    }
}
/// Runs a depth-first traversal from `start` and returns the discovery
/// timestamp of every node that was reached.
///
/// Timestamps start at 1 and increase in discovery order; neighbours are
/// explored in the order `graph.neighbors` yields them. Nodes that were not
/// reached are absent from the map. If `start` is not among `graph.nodes()`
/// the map is empty.
///
/// The traversal keeps its own stack on the heap, so the depth of the graph
/// is not limited by the thread's call stack.
pub fn simple_dfs<G>(graph: &G, start: NodeIndex) -> HashMap<NodeIndex, usize>
where
    G: VirtualGraph<NodeData = (), EdgeData = ()>,
{
    let all_nodes: Vec<NodeIndex> = graph.nodes().map(|n| n.index()).collect();

    // Dense lookup from `NodeIndex::index()` to a slot in `discovery_time`;
    // `usize::MAX` marks indices that do not correspond to a graph node.
    let max_node_index = all_nodes.iter().map(|n| n.index()).max().unwrap_or(0);
    let mut node_to_idx: Vec<usize> = vec![usize::MAX; max_node_index + 1];
    for (slot, node) in all_nodes.iter().enumerate() {
        node_to_idx[node.index()] = slot;
    }
    let slot_of = |node: NodeIndex| {
        node_to_idx
            .get(node.index())
            .copied()
            .filter(|&slot| slot != usize::MAX)
    };

    // `usize::MAX` doubles as the "not yet visited" marker.
    let mut discovery_time: Vec<usize> = vec![usize::MAX; all_nodes.len()];
    let mut time_counter = 0usize;

    if let Some(start_slot) = slot_of(start) {
        time_counter += 1;
        discovery_time[start_slot] = time_counter;

        // One pending-neighbour iterator per node on the current DFS path;
        // this mirrors the call stack of a recursive implementation.
        let mut stack = vec![graph.neighbors(start).collect::<Vec<NodeIndex>>().into_iter()];
        while let Some(pending) = stack.last_mut() {
            let Some(neighbor) = pending.next() else {
                // Every neighbour handled: backtrack.
                stack.pop();
                continue;
            };
            let Some(slot) = slot_of(neighbor) else {
                continue;
            };
            if discovery_time[slot] != usize::MAX {
                continue;
            }
            time_counter += 1;
            discovery_time[slot] = time_counter;
            stack.push(
                graph
                    .neighbors(neighbor)
                    .collect::<Vec<NodeIndex>>()
                    .into_iter(),
            );
        }
    }

    all_nodes
        .iter()
        .copied()
        .zip(discovery_time)
        .filter(|&(_, stamp)| stamp != usize::MAX)
        .collect()
}
/// Computes the strongly connected components of `graph` with Tarjan's
/// algorithm.
///
/// Each inner vector holds the nodes of one component in the order they are
/// popped from the algorithm's stack. Components are emitted as they are
/// completed.
///
/// The search is recursive: its call depth grows with the length of the
/// longest path it follows.
pub fn tarjan_scc<G>(graph: &G) -> Vec<Vec<NodeIndex>>
where
    G: VirtualGraph<NodeData = (), EdgeData = ()>,
{
    /// Everything the recursive search reads and updates.
    struct Search<'a> {
        /// `NodeIndex::index()` -> slot; `usize::MAX` for unknown indices.
        slot_of: &'a [usize],
        /// Last index number handed out (the first node gets 1).
        counter: usize,
        /// Index number per slot; `usize::MAX` means "not visited yet".
        index: Vec<usize>,
        lowlink: Vec<usize>,
        on_stack: Vec<bool>,
        stack: Vec<NodeIndex>,
        components: Vec<Vec<NodeIndex>>,
    }

    fn visit<G>(graph: &G, v: NodeIndex, v_slot: usize, search: &mut Search<'_>)
    where
        G: VirtualGraph<NodeData = (), EdgeData = ()>,
    {
        search.counter += 1;
        search.index[v_slot] = search.counter;
        search.lowlink[v_slot] = search.counter;
        search.stack.push(v);
        search.on_stack[v_slot] = true;

        for w in graph.neighbors(v) {
            let w_slot = match search.slot_of.get(w.index()) {
                Some(&slot) if slot != usize::MAX => slot,
                _ => continue,
            };
            if search.index[w_slot] == usize::MAX {
                // Unvisited neighbour: recurse, then inherit its lowlink.
                visit(graph, w, w_slot, search);
                search.lowlink[v_slot] = search.lowlink[v_slot].min(search.lowlink[w_slot]);
            } else if search.on_stack[w_slot] {
                // Neighbour is on the stack, hence in the current component.
                search.lowlink[v_slot] = search.lowlink[v_slot].min(search.index[w_slot]);
            }
        }

        if search.lowlink[v_slot] != search.index[v_slot] {
            return;
        }

        // `v` is the root of a component: pop the stack down to and including it.
        let mut component = Vec::new();
        while let Some(w) = search.stack.pop() {
            if let Some(&slot) = search.slot_of.get(w.index()) {
                if slot != usize::MAX {
                    search.on_stack[slot] = false;
                }
            }
            component.push(w);
            if w == v {
                break;
            }
        }
        search.components.push(component);
    }

    let node_count = graph.node_count();
    let all_nodes: Vec<NodeIndex> = graph.nodes().map(|n| n.index()).collect();
    let total = all_nodes.len();

    let highest_index = all_nodes.iter().map(|n| n.index()).max().unwrap_or(0);
    let mut slot_of: Vec<usize> = vec![usize::MAX; highest_index + 1];
    for (slot, node) in all_nodes.iter().enumerate() {
        slot_of[node.index()] = slot;
    }

    let mut search = Search {
        slot_of: &slot_of,
        counter: 0,
        index: vec![usize::MAX; total],
        lowlink: vec![usize::MAX; total],
        on_stack: vec![false; total],
        stack: Vec::with_capacity(node_count),
        // Capacity hint only; the vector grows as needed.
        components: Vec::with_capacity((node_count / 4).max(16)),
    };

    for &node in &all_nodes {
        let slot = slot_of.get(node.index()).copied().unwrap_or(usize::MAX);
        if slot != usize::MAX && search.index[slot] == usize::MAX {
            visit(graph, node, slot, &mut search);
        }
    }

    search.components
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parallel::partitioner::{HashPartitioner, Partitioner};
    use crate::graph::Graph;
    use crate::graph::traits::GraphOps;

    /// The builder methods store exactly what they are given.
    #[test]
    fn test_dfs_config() {
        let cfg = DFSConfig::new()
            .with_record_path(true)
            .with_max_depth(10)
            .with_iterative(false);

        assert!(cfg.record_path);
        assert_eq!(cfg.max_depth, Some(10));
        assert!(!cfg.iterative);
    }

    /// A 10-node chain is fully traversed from its first node.
    #[test]
    fn test_distributed_dfs_basic() {
        let mut chain = Graph::<(), ()>::undirected();
        let ids: Vec<NodeIndex> = (0..10).map(|_| chain.add_node(()).unwrap()).collect();
        for pair in ids.windows(2) {
            chain.add_edge(pair[0], pair[1], ()).unwrap();
        }

        let partitioner = HashPartitioner::new(2);
        let parts = partitioner.partition_graph(&chain);
        let outcome = DistributedDFS::new(ids[0]).compute(&chain, &parts);

        assert_eq!(outcome.visited_count, 10);
        assert!(outcome.is_visited(ids[0]));
        assert!(outcome.is_visited(ids[9]));
        assert!(outcome.discovery(ids[0]).unwrap() < outcome.discovery(ids[9]).unwrap());
    }

    /// A depth limit of 3 stops the chain traversal after four nodes.
    #[test]
    fn test_distributed_dfs_max_depth() {
        let mut chain = Graph::<(), ()>::undirected();
        let ids: Vec<NodeIndex> = (0..10).map(|_| chain.add_node(()).unwrap()).collect();
        for pair in ids.windows(2) {
            chain.add_edge(pair[0], pair[1], ()).unwrap();
        }

        let partitioner = HashPartitioner::new(2);
        let parts = partitioner.partition_graph(&chain);
        let limited = DFSConfig::new().with_max_depth(3);
        let outcome = DistributedDFS::from_config(ids[0], limited).compute(&chain, &parts);

        assert!(outcome.visited_count <= 4);
        assert!(outcome.is_visited(ids[0]));
        assert!(outcome.is_visited(ids[3]));
    }

    /// The reconstructed path runs from the start node to the target.
    #[test]
    fn test_dfs_result_reconstruct_path() {
        let mut chain = Graph::<(), ()>::undirected();
        let ids: Vec<NodeIndex> = (0..5).map(|_| chain.add_node(()).unwrap()).collect();
        for pair in ids.windows(2) {
            chain.add_edge(pair[0], pair[1], ()).unwrap();
        }

        let partitioner = HashPartitioner::new(2);
        let parts = partitioner.partition_graph(&chain);
        let outcome = DistributedDFS::new(ids[0]).compute(&chain, &parts);

        let route = outcome.reconstruct_path(ids[4]);
        assert!(!route.is_empty());
        assert_eq!(route[0], ids[0]);
        assert_eq!(*route.last().unwrap(), ids[4]);
    }

    /// An empty graph yields an empty result rather than a panic.
    #[test]
    fn test_distributed_dfs_empty_graph() {
        let empty = Graph::<(), ()>::undirected();
        let partitioner = HashPartitioner::new(2);
        let parts = partitioner.partition_graph(&empty);

        let outcome = DistributedDFS::new(NodeIndex::new_public(0)).compute(&empty, &parts);

        assert_eq!(outcome.visited_count, 0);
        assert!(outcome.node_ids.is_empty());
    }

    /// Nodes without edges to the start component stay unvisited.
    #[test]
    fn test_distributed_dfs_isolated_nodes() {
        let mut g = Graph::<(), ()>::undirected();
        let ids: Vec<NodeIndex> = (0..5).map(|_| g.add_node(()).unwrap()).collect();
        g.add_edge(ids[0], ids[1], ()).unwrap();
        g.add_edge(ids[1], ids[2], ()).unwrap();

        let partitioner = HashPartitioner::new(2);
        let parts = partitioner.partition_graph(&g);
        let outcome = DistributedDFS::new(ids[0]).compute(&g, &parts);

        for &connected in &ids[..3] {
            assert!(outcome.is_visited(connected));
        }
        for &isolated in &ids[3..] {
            assert!(!outcome.is_visited(isolated));
        }
    }

    /// `validate` rejects a zero depth limit and accepts sensible values.
    #[test]
    fn test_dfs_config_validation() {
        let zero_depth = DFSConfig::new().with_max_depth(0);
        assert_eq!(zero_depth.validate(), Err(DFSConfigError::ZeroMaxDepth));

        let untouched = DFSConfig::new();
        assert!(untouched.validate().is_ok());

        let shallow = DFSConfig::new().with_max_depth(5);
        assert!(shallow.validate().is_ok());
    }

    /// `try_new` succeeds for the defaults; a zero limit still fails validation.
    #[test]
    fn test_dfs_try_new() {
        assert!(DFSConfig::try_new().is_ok());

        let zero_depth = DFSConfig::new().with_max_depth(0);
        assert!(zero_depth.validate().is_err());
    }

    /// `simple_dfs` reports a discovery time for every node of a chain.
    #[test]
    fn test_simple_dfs() {
        let mut chain = Graph::<(), ()>::undirected();
        let ids: Vec<NodeIndex> = (0..5).map(|_| chain.add_node(()).unwrap()).collect();
        for pair in ids.windows(2) {
            chain.add_edge(pair[0], pair[1], ()).unwrap();
        }

        let times = simple_dfs(&chain, ids[0]);

        assert_eq!(times.len(), 5);
        assert!(times.contains_key(&ids[0]));
        assert!(times.contains_key(&ids[4]));
    }

    /// Only the start node's component is counted as visited.
    #[test]
    fn test_dfs_disconnected_graph() {
        let mut g = Graph::<(), ()>::undirected();
        let ids: Vec<NodeIndex> = (0..6).map(|_| g.add_node(()).unwrap()).collect();
        g.add_edge(ids[0], ids[1], ()).unwrap();
        g.add_edge(ids[1], ids[2], ()).unwrap();

        let partitioner = HashPartitioner::new(2);
        let parts = partitioner.partition_graph(&g);
        let outcome = DistributedDFS::new(ids[0]).compute(&g, &parts);

        assert_eq!(outcome.visited_count, 3);
        assert!(!outcome.is_visited(ids[3]));
    }

    /// Per-partition visit counts add up to the overall visit count.
    #[test]
    fn test_partition_stats() {
        let mut g = Graph::<(), ()>::undirected();
        for _ in 0..20 {
            g.add_node(()).unwrap();
        }

        let partitioner = HashPartitioner::new(4);
        let parts = partitioner.partition_graph(&g);
        let origin = NodeIndex::new_public(0);
        let outcome = DistributedDFS::new(origin).compute(&g, &parts);

        assert_eq!(outcome.partition_stats.len(), 4);
        let summed: usize = outcome
            .partition_stats
            .iter()
            .map(|stats| stats.visited_count)
            .sum();
        assert_eq!(summed, outcome.visited_count);
    }
}