use crate::storage::CsrGraph;
use crate::NodeId;
use anyhow::Result;
use std::collections::{HashSet, VecDeque};
/// Finds every node that transitively calls `target`, up to `max_depth` hops away.
///
/// Performs a breadth-first search over incoming edges starting at `target`. A node at
/// distance 1 has an edge directly into `target`; a node at distance `d` reaches `target`
/// through a chain of `d` edges. Only nodes at distance `1..=max_depth` are reported, so
/// `max_depth == 0` yields an empty result without querying the graph.
///
/// The returned ids are the raw inner values of [`NodeId`], in breadth-first discovery
/// order: nearer callers come before farther ones, and callers at the same distance follow
/// the order produced by `incoming_neighbors`. Each caller appears at most once. `target`
/// itself is never included, even when it is reachable from itself through a cycle or a
/// self-edge.
///
/// # Errors
///
/// Propagates any error returned by `CsrGraph::incoming_neighbors`.
pub fn find_callers(graph: &CsrGraph, target: NodeId, max_depth: usize) -> Result<Vec<u32>> {
    if max_depth == 0 {
        return Ok(Vec::new());
    }

    // Seeding `visited` with the target keeps it out of the result even on cycles.
    let mut visited = HashSet::from([target.0]);
    let mut callers = Vec::new();
    // Entries are (node, distance from target). Only nodes whose distance is strictly
    // below `max_depth` are ever queued, so every popped node gets expanded.
    let mut queue = VecDeque::from([(target.0, 0usize)]);

    while let Some((current, depth)) = queue.pop_front() {
        // Cannot overflow: depth < max_depth <= usize::MAX.
        let next_depth = depth + 1;
        for &caller in graph.incoming_neighbors(NodeId(current))? {
            // `insert` returns false for already-seen nodes: one hash lookup, not two.
            if visited.insert(caller) {
                callers.push(caller);
                // Expand this caller only if its own callers are still within range.
                if next_depth < max_depth {
                    queue.push_back((caller, next_depth));
                }
            }
        }
    }

    Ok(callers)
}
/// Returns every node reachable from `source` by following outgoing edges.
///
/// The result includes `source` itself and lists raw node ids (the inner value of
/// [`NodeId`]) in breadth-first discovery order: `source` first, then nodes in order of
/// increasing hop count, with ties following the order produced by `outgoing_neighbors`.
/// Each node appears exactly once, so cycles terminate.
///
/// # Errors
///
/// Propagates any error returned by `CsrGraph::outgoing_neighbors`.
pub fn bfs(graph: &CsrGraph, source: NodeId) -> Result<Vec<u32>> {
    let mut visited = HashSet::from([source.0]);
    // Discovery order is recorded separately so the result is deterministic,
    // unlike iterating the `HashSet`.
    let mut order = vec![source.0];
    let mut queue = VecDeque::from([source.0]);

    while let Some(current) = queue.pop_front() {
        for &neighbor in graph.outgoing_neighbors(NodeId(current))? {
            // `insert` returns false for already-seen nodes: one hash lookup, not two.
            if visited.insert(neighbor) {
                order.push(neighbor);
                queue.push_back(neighbor);
            }
        }
    }

    Ok(order)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_find_callers_direct() {
        let edges = vec![(NodeId(0), NodeId(2), 1.0), (NodeId(1), NodeId(2), 1.0)];
        let graph = CsrGraph::from_edge_list(&edges).unwrap();
        let callers = find_callers(&graph, NodeId(2), 1).unwrap();
        assert_eq!(callers.len(), 2);
        assert!(callers.contains(&0));
        assert!(callers.contains(&1));
    }

    #[test]
    fn test_find_callers_transitive() {
        let edges = vec![(NodeId(0), NodeId(1), 1.0), (NodeId(1), NodeId(2), 1.0)];
        let graph = CsrGraph::from_edge_list(&edges).unwrap();
        let callers = find_callers(&graph, NodeId(2), 10).unwrap();
        assert_eq!(callers.len(), 2);
        assert!(callers.contains(&0));
        assert!(callers.contains(&1));
    }

    #[test]
    fn test_find_callers_depth_limit() {
        let edges = vec![
            (NodeId(0), NodeId(1), 1.0),
            (NodeId(1), NodeId(2), 1.0),
            (NodeId(2), NodeId(3), 1.0),
        ];
        let graph = CsrGraph::from_edge_list(&edges).unwrap();

        let callers = find_callers(&graph, NodeId(3), 1).unwrap();
        assert_eq!(callers.len(), 1);
        assert!(callers.contains(&2));

        let callers = find_callers(&graph, NodeId(3), 2).unwrap();
        assert_eq!(callers.len(), 2);
        assert!(callers.contains(&1));
        assert!(callers.contains(&2));

        let callers = find_callers(&graph, NodeId(3), 10).unwrap();
        assert_eq!(callers.len(), 3);
        assert!(callers.contains(&0));
        assert!(callers.contains(&1));
        assert!(callers.contains(&2));
    }

    #[test]
    fn test_find_callers_zero_depth() {
        let edges = vec![(NodeId(0), NodeId(1), 1.0)];
        let graph = CsrGraph::from_edge_list(&edges).unwrap();
        // A depth of zero means no hops are explored, so nothing is reported.
        let callers = find_callers(&graph, NodeId(1), 0).unwrap();
        assert!(callers.is_empty());
    }

    #[test]
    fn test_find_callers_excludes_target_in_cycle() {
        // 0 -> 1 -> 2 -> 1: node 2 is reachable from itself through node 1.
        let edges = vec![
            (NodeId(0), NodeId(1), 1.0),
            (NodeId(1), NodeId(2), 1.0),
            (NodeId(2), NodeId(1), 1.0),
        ];
        let graph = CsrGraph::from_edge_list(&edges).unwrap();
        let callers = find_callers(&graph, NodeId(2), 10).unwrap();
        assert_eq!(callers.len(), 2);
        assert!(callers.contains(&0));
        assert!(callers.contains(&1));
        assert!(!callers.contains(&2));
    }

    #[test]
    fn test_bfs_simple() {
        let edges = vec![(NodeId(0), NodeId(1), 1.0), (NodeId(1), NodeId(2), 1.0)];
        let graph = CsrGraph::from_edge_list(&edges).unwrap();
        let reachable = bfs(&graph, NodeId(0)).unwrap();
        assert_eq!(reachable.len(), 3);
        assert!(reachable.contains(&0));
        assert!(reachable.contains(&1));
        assert!(reachable.contains(&2));
    }

    #[test]
    fn test_bfs_disconnected() {
        let edges = vec![(NodeId(0), NodeId(1), 1.0), (NodeId(2), NodeId(3), 1.0)];
        let graph = CsrGraph::from_edge_list(&edges).unwrap();
        let reachable = bfs(&graph, NodeId(0)).unwrap();
        assert_eq!(reachable.len(), 2);
        assert!(reachable.contains(&0));
        assert!(reachable.contains(&1));
        assert!(!reachable.contains(&2));
        assert!(!reachable.contains(&3));
    }

    #[test]
    fn test_bfs_cycle_terminates() {
        let edges = vec![(NodeId(0), NodeId(1), 1.0), (NodeId(1), NodeId(0), 1.0)];
        let graph = CsrGraph::from_edge_list(&edges).unwrap();
        let reachable = bfs(&graph, NodeId(0)).unwrap();
        // Each node is reported exactly once despite the cycle.
        assert_eq!(reachable.len(), 2);
        assert!(reachable.contains(&0));
        assert!(reachable.contains(&1));
    }
}