use bitvec::vec::BitVec;
use log::{debug, info};
use std::time::Instant;
use crate::graph::{EdgeID, Graph, INVALID_NODE_ID, NodeID};
/// Depth-first search from a set of source nodes towards a set of target nodes.
///
/// All per-node state is sized to the `number_of_nodes` passed to `DFS::new`,
/// and node IDs are used directly as indices into that state.
pub struct DFS {
/// Nodes the search starts from; each one is its own parent.
sources: Vec<NodeID>,
/// One bit per node; a set bit marks the node as a target.
target_set: BitVec,
/// `parents[v]` is the node from which `v` was discovered, `v` itself for a
/// source, and `INVALID_NODE_ID` while `v` has not been reached.
parents: Vec<NodeID>,
/// Target node most recently reached by a run; starts as `INVALID_NODE_ID`.
target: NodeID,
/// Nodes that were discovered but whose outgoing edges are not scanned yet.
stack: Vec<usize>,
/// `true` when the search was built without targets; this is what a run
/// returns when it finishes without reaching any target.
empty_target_set: bool,
}
impl DFS {
    /// Creates a search from `source_list` towards `target_list` on a graph
    /// with `number_of_nodes` nodes.
    ///
    /// An empty `target_list` requests an exhaustive traversal: the run then
    /// visits everything reachable from the sources and reports success.
    ///
    /// # Panics
    ///
    /// Panics if any ID in `source_list` or `target_list` is not smaller than
    /// `number_of_nodes`.
    pub fn new(source_list: &[NodeID], target_list: &[NodeID], number_of_nodes: usize) -> Self {
        let mut dfs = Self {
            sources: source_list.to_vec(),
            target_set: BitVec::with_capacity(number_of_nodes),
            parents: Vec::new(),
            target: INVALID_NODE_ID,
            stack: Vec::new(),
            empty_target_set: target_list.is_empty(),
        };
        dfs.target_set.resize(number_of_nodes, false);
        for &t in target_list {
            dfs.target_set.set(t, true);
        }
        dfs.populate_sources(number_of_nodes);
        dfs
    }

    /// Resets `parents` to `number_of_nodes` unreached entries and then marks
    /// every source as its own parent.
    fn populate_sources(&mut self, number_of_nodes: usize) {
        // `clear` + `resize` (instead of a bare `resize`) also wipes entries
        // left over from a previous run, so this doubles as the per-run reset.
        self.parents.clear();
        self.parents.resize(number_of_nodes, INVALID_NODE_ID);
        for &s in &self.sources {
            self.parents[s] = s;
        }
    }

    /// Runs the search over every edge of `graph`.
    ///
    /// See [`DFS::run_with_filter`] for the meaning of the return value.
    pub fn run<T, G: Graph<T>>(&mut self, graph: &G) -> bool {
        self.run_with_filter(graph, |_graph, _edge| false)
    }

    /// Runs the search, ignoring every edge for which `filter` returns `true`.
    ///
    /// Returns `true` as soon as a node of the target set is discovered; that
    /// node is remembered for the `fetch_*` methods and `path_iter`. If the
    /// reachable part of the graph is exhausted first, the result is `true`
    /// only when the target set was empty, and no target is remembered.
    ///
    /// Sources are marked as visited up front, so a source that is also a
    /// target is never reported as found.
    ///
    /// # Panics
    ///
    /// Panics if `graph` yields a node ID that is not smaller than the
    /// `number_of_nodes` given to [`DFS::new`].
    pub fn run_with_filter<T, F, G: Graph<T>>(&mut self, graph: &G, filter: F) -> bool
    where
        F: Fn(&G, EdgeID) -> bool,
    {
        let start = Instant::now();

        // Drop all state of a previous run. Without clearing `target`, a run
        // that fails after an earlier success would leave a stale target whose
        // parent chain no longer exists.
        self.target = INVALID_NODE_ID;
        let number_of_nodes = self.parents.len();
        self.populate_sources(number_of_nodes);
        self.stack.clear();
        self.stack.extend(self.sources.iter().copied());

        while let Some(node) = self.stack.pop() {
            for edge in graph.edge_range(node) {
                if filter(graph, edge) {
                    continue;
                }
                let target = graph.target(edge);
                // Any entry other than INVALID_NODE_ID means "already
                // discovered"; this includes sources, which are their own
                // parent.
                if self.parents[target] != INVALID_NODE_ID {
                    continue;
                }
                self.parents[target] = node;

                // SAFETY: `parents[target]` was bounds-checked just above, and
                // `target_set` always has the same length as `parents`: both
                // are sized to `number_of_nodes` in `new` and never change
                // length afterwards.
                let is_target = unsafe { *self.target_set.get_unchecked(target) };
                if is_target {
                    self.target = target;
                    debug!("setting target {}", self.target);
                    let duration = start.elapsed();
                    info!("D/DFS took: {duration:?} (done)");
                    return true;
                }
                self.stack.push(target);
            }
        }

        let duration = start.elapsed();
        info!("DFS took: {duration:?} (done)");
        self.empty_target_set
    }

    /// Returns the nodes on the path from a source to the target reached by
    /// the last run, in source-to-target order.
    ///
    /// The result is empty when the last run did not reach a target.
    pub fn fetch_node_path(&self) -> Vec<NodeID> {
        self.fetch_node_path_from_node(self.target)
    }

    /// Returns the nodes on the path from a source to `t`, in source-to-`t`
    /// order, following the parent pointers of the last run.
    ///
    /// The result is empty when `t` was not reached (or is not a node of the
    /// searched graph), matching what `path_iter` yields when there is no
    /// target.
    pub fn fetch_node_path_from_node(&self, t: NodeID) -> Vec<NodeID> {
        let mut path = Vec::new();
        let mut id = t;
        loop {
            let parent = match self.parents.get(id) {
                Some(&parent) if parent != INVALID_NODE_ID => parent,
                // Out of range or never discovered: there is no path.
                _ => return Vec::new(),
            };
            path.push(id);
            if parent == id {
                // A node that is its own parent is a source: walk complete.
                break;
            }
            id = parent;
        }
        path.reverse();
        path
    }

    /// Returns the edges on the path from a source to the target reached by
    /// the last run, in source-to-target order.
    ///
    /// Each edge is obtained from `graph.find_edge(parent, child)`.
    /// NOTE(review): with parallel edges or a filter this may not be the exact
    /// edge the search traversed; confirm against `find_edge`'s contract.
    ///
    /// The result is empty when the last run did not reach a target.
    ///
    /// # Panics
    ///
    /// Panics if `graph.find_edge` reports no edge between two consecutive
    /// nodes of the path, e.g. when `graph` is not the graph that was searched.
    pub fn fetch_edge_path<T>(&self, graph: &impl Graph<T>) -> Vec<EdgeID> {
        let mut path = Vec::new();
        let mut id = self.target;
        while let Some(&parent) = self.parents.get(id) {
            if parent == id || parent == INVALID_NODE_ID {
                break;
            }
            let edge_id = graph
                .find_edge(parent, id)
                .expect("every parent/child pair on a DFS path is joined by an edge");
            path.push(edge_id);
            id = parent;
        }
        path.reverse();
        path
    }

    /// Returns an iterator that walks from the reached target back to its
    /// source; it yields nothing when no target has been reached.
    pub fn path_iter(&self) -> PathIter<'_> {
        PathIter::new(self)
    }
}
/// Iterator over the nodes of a found path, walking from the reached target
/// back to the source by following the parent pointers of a `DFS`.
pub struct PathIter<'a> {
/// The search whose parent pointers are followed.
dfs: &'a DFS,
/// Node to yield next; `INVALID_NODE_ID` once the walk is finished.
id: usize,
}
impl PathIter<'_> {
pub fn new(dfs: &DFS) -> PathIter {
debug!("init: {}", dfs.target);
PathIter {
dfs,
id: dfs.target,
}
}
}
impl Iterator for PathIter<'_> {
    type Item = NodeID;

    /// Yields the current node and moves the cursor to its parent.
    ///
    /// A node that is its own parent is a source: it is still yielded, and the
    /// cursor is then parked on `INVALID_NODE_ID` so the following call
    /// returns `None`.
    fn next(&mut self) -> Option<NodeID> {
        let current = self.id;
        if current == INVALID_NODE_ID {
            return None;
        }

        let parent = self.dfs.parents[current];
        self.id = if parent == current {
            INVALID_NODE_ID
        } else {
            parent
        };
        Some(current)
    }
}
#[cfg(test)]
mod tests {
    use crate::edge::InputEdge;
    use crate::graph::Graph;
    use crate::{dfs::DFS, static_graph::StaticGraph};

    /// Builds the eight-edge directed graph on nodes 0..=5 that every test
    /// in this module searches.
    fn fixture_graph() -> StaticGraph<i32> {
        let edges = vec![
            InputEdge::new(0, 1, 3),
            InputEdge::new(1, 2, 3),
            InputEdge::new(4, 2, 1),
            InputEdge::new(2, 3, 6),
            InputEdge::new(0, 4, 2),
            InputEdge::new(4, 5, 2),
            InputEdge::new(5, 3, 7),
            InputEdge::new(1, 5, 2),
        ];
        StaticGraph::<i32>::new(edges)
    }

    /// Single source, single target: node path in both directions.
    #[test]
    fn s_t_query_fetch_node_string() {
        let graph = fixture_graph();
        let mut dfs = DFS::new(&[0], &[5], graph.number_of_nodes());
        assert!(dfs.run(&graph));

        let path = dfs.fetch_node_path();
        assert_eq!(path, vec![0, 4, 5]);

        let path: Vec<usize> = dfs.path_iter().collect();
        assert_eq!(path, vec![5, 4, 0]);
    }

    /// Single source, single target: the path expressed as edge IDs.
    #[test]
    fn s_t_query_edge_list() {
        let graph = fixture_graph();
        let mut dfs = DFS::new(&[0], &[5], graph.number_of_nodes());
        assert!(dfs.run(&graph));

        let path = dfs.fetch_edge_path(&graph);
        assert_eq!(path, vec![1, 6]);
    }

    /// Empty target set: the run succeeds, any reached node can be queried,
    /// and the target-based iterator stays empty.
    #[test]
    fn s_all_query() {
        let graph = fixture_graph();
        let mut dfs = DFS::new(&[0], &[], graph.number_of_nodes());
        assert!(dfs.run(&graph));

        let path = dfs.fetch_node_path_from_node(3);
        assert_eq!(path, vec![0, 4, 5, 3]);

        let path: Vec<usize> = dfs.path_iter().collect();
        assert!(path.is_empty());
    }

    /// Two sources and an empty target set.
    #[test]
    fn multi_s_all_query() {
        let graph = fixture_graph();
        let mut dfs = DFS::new(&[0, 1], &[], graph.number_of_nodes());
        assert!(dfs.run(&graph));

        let path = dfs.fetch_node_path_from_node(3);
        assert_eq!(path, vec![1, 5, 3]);

        let path: Vec<usize> = dfs.path_iter().collect();
        assert!(path.is_empty());
    }
}