use crate::graph::graph::Graph;
use std::collections::{HashMap, HashSet, VecDeque};
use std::fmt::Debug;
use std::hash::Hash;
/// Errors produced by the graph search routines in this module.
///
/// `Clone`, `PartialEq` and `Eq` are derived in addition to `Debug` so that
/// callers and tests can copy and compare errors directly (for example with
/// `assert_eq!`) instead of having to pattern-match on them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SearchError {
    /// The carried node ID does not refer to a node stored in the graph.
    InvalidNode(usize),
}
/// User-facing rendering of a `SearchError`.
impl std::fmt::Display for SearchError {
    /// Writes a one-line message naming the node ID that could not be found.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The enum currently has a single variant, so it can be destructured
        // irrefutably; adding a variant later makes this line a compile error,
        // which forces the message to be revisited.
        let SearchError::InvalidNode(missing) = self;
        write!(f, "Invalid node reference: node ID {} not found", missing)
    }
}
// Marker impl: the body is empty, so `SearchError` uses the default
// `std::error::Error` methods. Together with the `Debug` and `Display` impls
// above, this lets it be used wherever a standard error type is expected.
impl std::error::Error for SearchError {}
/// Reachability, path, and cycle queries over a graph whose nodes are
/// addressed by `usize` IDs.
///
/// None of `W`, `N` or `E` appears in a method signature; they only carry the
/// bounds listed in the `where` clause.
/// NOTE(review): presumably the weight, node-payload and edge-payload types of
/// the implementing graph — confirm against `Graph`.
///
/// Every method returns `Result<_, SearchError>`. The behavior notes below
/// describe the `Graph` implementation found in this file.
pub trait Search<W, N, E>
where
W: Copy + Default + PartialEq,
N: Clone + Eq + Hash + Debug,
E: Clone + Default + Debug,
{
/// Reports whether `target` can be reached from `start`.
///
/// The `Graph` implementation returns `SearchError::InvalidNode` when either
/// ID is not in the graph (checking `start` first) and delegates the search
/// to `dfs` with a fresh visited set.
fn has_path(&self, start: usize, target: usize) -> Result<bool, SearchError>;
/// Returns a path from `start` to `target` as a list of node IDs with both
/// endpoints included, or `None` when `target` is not reached.
///
/// The `Graph` implementation explores breadth-first and returns
/// `SearchError::InvalidNode` for an unknown endpoint or neighbor ID.
fn bfs_path(&self, start: usize, target: usize) -> Result<Option<Vec<usize>>, SearchError>;
/// Depth-first search step: reports whether `target` is reachable from
/// `current`.
///
/// `visited` is caller-owned state. The `Graph` implementation adds each
/// node it expands to `visited` and skips neighbors already in it.
fn dfs(
&self,
current: usize,
target: usize,
visited: &mut HashSet<usize>,
) -> Result<bool, SearchError>;
/// Reports whether `node` exists in the graph.
///
/// The `Graph` implementation always returns `Ok`.
fn has_node(&self, node: usize) -> Result<bool, SearchError>;
/// Reports whether the graph contains a cycle.
///
/// The `Graph` implementation starts a search from every not-yet-visited
/// node, using `has_cycle_directed` or `has_cycle_undirected` depending on
/// its `directed` flag.
fn has_cycle(&self) -> Result<bool, SearchError>;
/// Cycle-detection step for directed graphs, starting at `node`.
///
/// `visited` holds nodes that have already been explored; `recursion_stack`
/// holds the nodes on the path currently being explored. In the `Graph`
/// implementation, reaching a node that is in `recursion_stack` is reported
/// as a cycle.
fn has_cycle_directed(
&self,
node: usize,
visited: &mut HashSet<usize>,
recursion_stack: &mut HashSet<usize>,
) -> Result<bool, SearchError>;
/// Cycle-detection step for undirected graphs, starting at `node`.
///
/// `parent` is the node this step was reached from (`None` for a search
/// root). The `Graph` implementation skips neighbors equal to `parent` and
/// reports a cycle when it reaches a node already in `visited`.
fn has_cycle_undirected(
&self,
node: usize,
parent: Option<usize>,
visited: &mut HashSet<usize>,
) -> Result<bool, SearchError>;
}
/// `Search` for `Graph`, implemented by walking each node's `neighbors` list.
///
/// Node IDs are checked with `self.nodes.contains`; an ID that is not present
/// is reported as `SearchError::InvalidNode` (except by `has_node`, which
/// answers `Ok(false)` instead).
impl<W, N, E> Search<W, N, E> for Graph<W, N, E>
where
    W: Copy + Default + PartialEq,
    N: Clone + Eq + Hash + Debug,
    E: Clone + Default + Debug,
{
    /// Reports whether `target` is reachable from `start` by following
    /// neighbor links.
    ///
    /// # Errors
    /// `SearchError::InvalidNode` when `start` or `target` is not in the
    /// graph, or when a neighbor list visited by the search names an ID that
    /// is not in the graph.
    fn has_path(&self, start: usize, target: usize) -> Result<bool, SearchError> {
        // `start` is checked before `target`, so when both are missing the
        // error names `start`.
        for &endpoint in &[start, target] {
            if !self.nodes.contains(endpoint) {
                return Err(SearchError::InvalidNode(endpoint));
            }
        }
        let mut visited = HashSet::new();
        self.dfs(start, target, &mut visited)
    }

    /// Breadth-first search from `start`; returns the node IDs of the path it
    /// finds to `target` (both endpoints included), or `None` when the queue
    /// drains without reaching `target`.
    ///
    /// # Errors
    /// `SearchError::InvalidNode` when `start` or `target` is not in the
    /// graph, or when a newly discovered neighbor ID is not in the graph.
    fn bfs_path(&self, start: usize, target: usize) -> Result<Option<Vec<usize>>, SearchError> {
        for &endpoint in &[start, target] {
            if !self.nodes.contains(endpoint) {
                return Err(SearchError::InvalidNode(endpoint));
            }
        }

        let mut seen: HashSet<usize> = HashSet::new();
        // Maps each discovered node to the node it was discovered from.
        let mut came_from: HashMap<usize, usize> = HashMap::new();
        let mut frontier: VecDeque<usize> = VecDeque::new();
        seen.insert(start);
        frontier.push_back(start);

        while let Some(here) = frontier.pop_front() {
            if here == target {
                // Follow predecessor links back until a node without one
                // (`start`) is reached, then flip into start-to-target order.
                let mut path: Vec<usize> =
                    std::iter::successors(Some(here), |step| came_from.get(step).copied())
                        .collect();
                path.reverse();
                return Ok(Some(path));
            }
            for &(next, _) in &self.nodes[here].neighbors {
                // `insert` returns false when `next` was already discovered.
                if !seen.insert(next) {
                    continue;
                }
                if !self.nodes.contains(next) {
                    return Err(SearchError::InvalidNode(next));
                }
                came_from.insert(next, here);
                frontier.push_back(next);
            }
        }
        Ok(None)
    }

    /// Recursive depth-first step: reports whether `target` is reachable from
    /// `current`.
    ///
    /// `visited` is owned by the caller. `current` is added to it unless it is
    /// the target, and neighbors already in it are not explored again.
    ///
    /// # Errors
    /// `SearchError::InvalidNode` when `current`, `target`, or an unvisited
    /// neighbor ID is not in the graph.
    fn dfs(
        &self,
        current: usize,
        target: usize,
        visited: &mut HashSet<usize>,
    ) -> Result<bool, SearchError> {
        for &endpoint in &[current, target] {
            if !self.nodes.contains(endpoint) {
                return Err(SearchError::InvalidNode(endpoint));
            }
        }
        if current == target {
            return Ok(true);
        }
        visited.insert(current);
        for &(next, _) in &self.nodes[current].neighbors {
            // Re-checked on every iteration because the recursive calls below
            // keep adding to `visited`.
            if visited.contains(&next) {
                continue;
            }
            // Validate before recursing so an unknown ID is never recorded in
            // the caller's `visited` set.
            if !self.nodes.contains(next) {
                return Err(SearchError::InvalidNode(next));
            }
            if self.dfs(next, target, visited)? {
                return Ok(true);
            }
        }
        Ok(false)
    }

    /// Reports whether `node` is an ID stored in the graph. Always `Ok`.
    fn has_node(&self, node: usize) -> Result<bool, SearchError> {
        let present = self.nodes.contains(node);
        Ok(present)
    }

    /// Reports whether the graph contains a cycle, starting a search from
    /// every node that an earlier search has not already visited.
    ///
    /// # Errors
    /// Propagates any error from `has_cycle_directed` / `has_cycle_undirected`.
    fn has_cycle(&self) -> Result<bool, SearchError> {
        let mut visited = HashSet::new();
        // Only consulted by the directed search.
        let mut recursion_stack = HashSet::new();
        for (node_id, _) in self.nodes.iter() {
            if visited.contains(&node_id) {
                continue;
            }
            let found = if self.directed {
                self.has_cycle_directed(node_id, &mut visited, &mut recursion_stack)?
            } else {
                self.has_cycle_undirected(node_id, None, &mut visited)?
            };
            if found {
                return Ok(true);
            }
        }
        Ok(false)
    }

    /// Directed cycle-detection step starting at `node`.
    ///
    /// `recursion_stack` holds the nodes on the path currently being explored;
    /// arriving at one of them again means a cycle. `visited` holds every node
    /// that has been entered, so finished nodes are not explored twice.
    ///
    /// # Errors
    /// `SearchError::InvalidNode` when `node` or one of its neighbor IDs is not
    /// in the graph.
    fn has_cycle_directed(
        &self,
        node: usize,
        visited: &mut HashSet<usize>,
        recursion_stack: &mut HashSet<usize>,
    ) -> Result<bool, SearchError> {
        if !self.nodes.contains(node) {
            return Err(SearchError::InvalidNode(node));
        }
        if recursion_stack.contains(&node) {
            return Ok(true);
        }
        // `insert` returns false when `node` was explored earlier; it is not
        // on the current path (checked above), so no cycle goes through it.
        if !visited.insert(node) {
            return Ok(false);
        }
        recursion_stack.insert(node);
        if let Some(entry) = self.nodes.get(node) {
            for &(next, _) in &entry.neighbors {
                if !self.nodes.contains(next) {
                    return Err(SearchError::InvalidNode(next));
                }
                if self.has_cycle_directed(next, visited, recursion_stack)? {
                    return Ok(true);
                }
            }
        }
        // Only reached when no cycle was found below `node`: take it off the
        // current path before returning to the caller.
        recursion_stack.remove(&node);
        Ok(false)
    }

    /// Undirected cycle-detection step starting at `node`.
    ///
    /// `parent` is the node this step was entered from (`None` for a search
    /// root). Neighbors equal to `parent` are skipped so that walking straight
    /// back to the previous node is not reported as a cycle.
    ///
    /// # Errors
    /// `SearchError::InvalidNode` when `node` or a non-parent neighbor ID is
    /// not in the graph.
    fn has_cycle_undirected(
        &self,
        node: usize,
        parent: Option<usize>,
        visited: &mut HashSet<usize>,
    ) -> Result<bool, SearchError> {
        if !self.nodes.contains(node) {
            return Err(SearchError::InvalidNode(node));
        }
        // `insert` returns false when `node` was already visited, i.e. it has
        // been reached a second time through a non-parent neighbor.
        if !visited.insert(node) {
            return Ok(true);
        }
        if let Some(entry) = self.nodes.get(node) {
            for &(next, _) in &entry.neighbors {
                if parent == Some(next) {
                    continue;
                }
                if !self.nodes.contains(next) {
                    return Err(SearchError::InvalidNode(next));
                }
                if self.has_cycle_undirected(next, Some(node), visited)? {
                    return Ok(true);
                }
            }
        }
        Ok(false)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Graph type shared by every test: `u32` weights, unit node and edge payloads.
    type TestGraph = Graph<u32, (), ()>;

    /// Builds a graph with `node_count` nodes and one weight-1 edge per
    /// `(from, to)` pair, where the pair holds positions into the returned ID
    /// list (nodes are added first, in order, then edges in the given order).
    fn build(
        directed: bool,
        node_count: usize,
        edges: &[(usize, usize)],
    ) -> (TestGraph, Vec<usize>) {
        let mut graph = TestGraph::new(directed);
        let ids: Vec<usize> = (0..node_count).map(|_| graph.add_node(())).collect();
        for &(from, to) in edges {
            graph.add_edge(ids[from], ids[to], 1, ()).unwrap();
        }
        (graph, ids)
    }

    /// An undirected edge is reachable in both directions; an unknown start
    /// ID is rejected.
    #[test]
    fn test_has_path() {
        let (graph, ids) = build(false, 2, &[(0, 1)]);
        let (n0, n1) = (ids[0], ids[1]);
        assert!(graph.has_path(n0, n1).unwrap());
        assert!(graph.has_path(n1, n0).unwrap());
        assert!(matches!(
            graph.has_path(999, n0),
            Err(SearchError::InvalidNode(999))
        ));
    }

    /// BFS over a directed three-node chain returns the full chain.
    #[test]
    fn test_bfs_path() {
        let (graph, ids) = build(true, 3, &[(0, 1), (1, 2)]);
        let (n0, n1, n2) = (ids[0], ids[1], ids[2]);
        assert_eq!(graph.bfs_path(n0, n2).unwrap(), Some(vec![n0, n1, n2]));
        assert!(matches!(
            graph.bfs_path(999, n0),
            Err(SearchError::InvalidNode(999))
        ));
    }

    /// DFS reaches the far end of an undirected chain; the same visited set is
    /// reused for the invalid-start call.
    #[test]
    fn test_dfs() {
        let (graph, ids) = build(false, 3, &[(0, 1), (1, 2)]);
        let (n0, n2) = (ids[0], ids[2]);
        let mut visited = HashSet::new();
        assert!(graph.dfs(n0, n2, &mut visited).unwrap());
        assert!(matches!(
            graph.dfs(999, n0, &mut visited),
            Err(SearchError::InvalidNode(999))
        ));
    }

    /// On an empty graph `has_node` answers false and path queries report the
    /// start ID as invalid.
    #[test]
    fn test_invalid_nodes() {
        let graph = TestGraph::new(false);
        assert!(!graph.has_node(0).unwrap());
        assert!(matches!(
            graph.has_path(0, 1),
            Err(SearchError::InvalidNode(0))
        ));
        assert!(matches!(
            graph.bfs_path(0, 1),
            Err(SearchError::InvalidNode(0))
        ));
    }

    /// A directed triangle is a cycle.
    #[test]
    fn test_cycle_detection_directed() {
        let (graph, _) = build(true, 3, &[(0, 1), (1, 2), (2, 0)]);
        assert!(graph.has_cycle().unwrap());
    }

    /// An undirected triangle is a cycle.
    #[test]
    fn test_cycle_detection_undirected() {
        let (graph, _) = build(false, 3, &[(0, 1), (1, 2), (2, 0)]);
        assert!(graph.has_cycle().unwrap());
    }

    /// A directed chain has no cycle.
    #[test]
    fn test_no_cycle_directed() {
        let (graph, _) = build(true, 3, &[(0, 1), (1, 2)]);
        assert!(!graph.has_cycle().unwrap());
    }

    /// An undirected chain has no cycle.
    #[test]
    fn test_no_cycle_undirected() {
        let (graph, _) = build(false, 3, &[(0, 1), (1, 2)]);
        assert!(!graph.has_cycle().unwrap());
    }
}