use std::collections::{HashSet, VecDeque};
use crate::backend::GraphBackend;
use crate::errors::SqliteGraphError;
/// A fallible graph-traversal iterator that can report how many nodes it has produced.
///
/// Each item is either a node id or the backend error encountered while expanding a node.
pub trait GraphIterator: Iterator<Item = Result<i64, SqliteGraphError>> {
    /// Number of nodes this iterator has emitted so far (each implementation bumps its
    /// counter as it hands out a node).
    fn visited_count(&self) -> usize;
}
/// Lazy breadth-first traversal over outgoing edges that never goes deeper than
/// `max_depth` hops from the start node.
///
/// The start node itself is never yielded. Build one with [`bfs_iter`] or [`BfsIter::new`].
pub struct BfsIter<'a> {
    /// Backend queried for outgoing neighbours.
    graph: &'a dyn GraphBackend,
    /// FIFO frontier of `(node, depth)` pairs still to be processed.
    queue: VecDeque<(i64, u32)>,
    /// Every node ever enqueued (including the start), so nothing is enqueued twice.
    visited: HashSet<i64>,
    /// Hop limit measured from `start`.
    max_depth: u32,
    /// Origin of the traversal; excluded from the output.
    start: i64,
    /// Number of nodes yielded so far.
    count: usize,
}

impl<'a> BfsIter<'a> {
    /// Creates a traversal rooted at `start` bounded by `max_depth` hops.
    ///
    /// The start is marked visited immediately so edges leading back to it are ignored.
    pub fn new(graph: &'a dyn GraphBackend, start: i64, max_depth: u32) -> Self {
        Self {
            graph,
            queue: std::iter::once((start, 0)).collect(),
            visited: std::iter::once(start).collect(),
            max_depth,
            start,
            count: 0,
        }
    }
}

impl GraphIterator for BfsIter<'_> {
    /// Returns how many nodes this traversal has yielded so far.
    fn visited_count(&self) -> usize {
        self.count
    }
}
impl Iterator for BfsIter<'_> {
    type Item = Result<i64, SqliteGraphError>;

    /// Yields the next reachable node in breadth-first order.
    ///
    /// Every node within `max_depth` hops of the start (other than the start itself) is
    /// yielded exactly once. Nodes sitting exactly at `max_depth` are yielded but not
    /// expanded. If fetching a node's neighbours fails, the error is yielded in place of
    /// that node; later calls continue with the remaining frontier.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            let (node, depth) = self.queue.pop_front()?;
            // Only nodes strictly inside the limit are expanded; nodes at the limit are leaves.
            if depth < self.max_depth {
                match self.graph.fetch_outgoing(node) {
                    Ok(neighbors) => {
                        for neighbor in neighbors {
                            if self.visited.insert(neighbor) {
                                self.queue.push_back((neighbor, depth + 1));
                            }
                        }
                    }
                    Err(e) => return Some(Err(e)),
                }
            }
            // The start node seeds the search but is not part of the output.
            if node != self.start {
                self.count += 1;
                return Some(Ok(node));
            }
        }
    }
}

/// Lazily walks outgoing edges breadth-first from `start`, reaching at most `max_depth`
/// hops. Convenience wrapper around [`BfsIter::new`].
pub fn bfs_iter<'a>(graph: &'a dyn GraphBackend, start: i64, max_depth: u32) -> BfsIter<'a> {
    BfsIter::new(graph, start, max_depth)
}
/// Lazy stack-driven traversal over outgoing edges that yields every node reachable from
/// the start node (the start itself is not yielded).
///
/// A node's neighbours are discovered when that node is popped from the stack, and
/// discovered nodes are handed out in discovery order.
pub struct DfsIter<'a> {
    /// Backend queried for outgoing neighbours.
    graph: &'a dyn GraphBackend,
    /// Nodes whose outgoing edges have not been expanded yet (LIFO).
    stack: Vec<i64>,
    /// Every node discovered so far, including the start.
    visited: HashSet<i64>,
    /// Discovered nodes not yet returned by `next`, in discovery order.
    pending: VecDeque<i64>,
    /// Number of nodes yielded so far.
    count: usize,
}

impl<'a> DfsIter<'a> {
    /// Creates a traversal rooted at `start`: the start is pre-marked visited and placed on
    /// the stack to be expanded first.
    pub fn new(graph: &'a dyn GraphBackend, start: i64) -> Self {
        Self {
            graph,
            stack: vec![start],
            visited: std::iter::once(start).collect(),
            pending: VecDeque::new(),
            count: 0,
        }
    }
}

impl GraphIterator for DfsIter<'_> {
    /// Returns how many nodes this traversal has yielded so far.
    fn visited_count(&self) -> usize {
        self.count
    }
}
impl Iterator for DfsIter<'_> {
    type Item = Result<i64, SqliteGraphError>;

    /// Yields the next discovered node, expanding stacked nodes on demand until one appears.
    ///
    /// Returns `None` once the stack is empty and nothing is pending. A backend error while
    /// expanding a node is yielded as `Err`, and that node's neighbours are not discovered.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            if let Some(found) = self.pending.pop_front() {
                self.count += 1;
                return Some(Ok(found));
            }
            // Nothing buffered: expand the most recently stacked node to discover more.
            let current = self.stack.pop()?;
            let outgoing = match self.graph.fetch_outgoing(current) {
                Ok(ids) => ids,
                Err(err) => return Some(Err(err)),
            };
            for next_id in outgoing {
                if self.visited.insert(next_id) {
                    self.stack.push(next_id);
                    self.pending.push_back(next_id);
                }
            }
        }
    }
}

/// Lazily yields every node reachable from `start` via outgoing edges, excluding `start`.
/// Convenience wrapper around [`DfsIter::new`].
pub fn dfs_iter<'a>(graph: &'a dyn GraphBackend, start: i64) -> DfsIter<'a> {
    DfsIter::new(graph, start)
}
/// Lazy topological ordering of all entities using Kahn's algorithm.
///
/// In-degrees are computed eagerly in [`TopologicalSortIter::new`]. Nodes are then yielded
/// one at a time as their remaining in-degree drops to zero. Members of a cycle (and anything
/// only reachable through one) are never yielded. Call [`TopologicalSortIter::has_cycle`]
/// after the iterator is exhausted to detect that case.
pub struct TopologicalSortIter<'a> {
    graph: &'a dyn GraphBackend,
    /// Nodes whose remaining in-degree is zero and which are ready to be yielded.
    queue: VecDeque<i64>,
    /// Remaining (not yet released) incoming-edge count per node.
    in_degree: std::collections::HashMap<i64, usize>,
    /// Number of nodes yielded so far.
    count: usize,
    /// Number of entities reported by the backend when the iterator was built.
    total_nodes: usize,
}

impl<'a> TopologicalSortIter<'a> {
    /// Builds the iterator by computing the in-degree of every entity.
    ///
    /// This performs one `all_entity_ids` call plus one `fetch_outgoing` call per entity up
    /// front. The initial zero-in-degree nodes are queued in the order `all_entity_ids`
    /// returned them, so the output order is reproducible for a given backend order
    /// (assumes the returned ids are unique).
    ///
    /// # Errors
    /// Propagates any backend error from listing entities or fetching outgoing edges.
    pub fn new(graph: &'a dyn GraphBackend) -> Result<Self, SqliteGraphError> {
        let all_ids = graph.all_entity_ids()?;
        let mut in_degree: std::collections::HashMap<i64, usize> =
            all_ids.iter().map(|&node| (node, 0)).collect();
        for &node in &all_ids {
            for target in graph.fetch_outgoing(node)? {
                *in_degree.entry(target).or_insert(0) += 1;
            }
        }
        // Seed from `all_ids` rather than from the map: `HashMap` iteration order is
        // arbitrary and, with the default hasher, differs between runs.
        let queue: VecDeque<i64> = all_ids
            .iter()
            .copied()
            .filter(|node| in_degree.get(node) == Some(&0))
            .collect();
        Ok(Self {
            graph,
            queue,
            in_degree,
            count: 0,
            total_nodes: all_ids.len(),
        })
    }

    /// Returns `true` if fewer nodes have been yielded than the graph contains.
    ///
    /// Only meaningful once the iterator has been exhausted. Before that, any non-empty graph
    /// reports `true` because nodes are still outstanding.
    pub fn has_cycle(&self) -> bool {
        self.count < self.total_nodes
    }

    /// Number of entities reported by the backend when the iterator was built.
    pub fn total_nodes(&self) -> usize {
        self.total_nodes
    }
}

impl GraphIterator for TopologicalSortIter<'_> {
    /// Returns how many nodes have been yielded in topological order so far.
    fn visited_count(&self) -> usize {
        self.count
    }
}
impl Iterator for TopologicalSortIter<'_> {
    type Item = Result<i64, SqliteGraphError>;

    /// Yields the next node whose predecessors have all been yielded.
    ///
    /// Before returning the node, its outgoing edges are released, which decrements each
    /// target's remaining in-degree. If fetching those edges fails, the error is yielded
    /// instead of the node, and the node is not counted towards
    /// [`GraphIterator::visited_count`] or [`TopologicalSortIter::has_cycle`].
    fn next(&mut self) -> Option<Self::Item> {
        let node = self.queue.pop_front()?;
        let neighbors = match self.graph.fetch_outgoing(node) {
            Ok(neighbors) => neighbors,
            Err(e) => return Some(Err(e)),
        };
        for neighbor in neighbors {
            if let Some(deg) = self.in_degree.get_mut(&neighbor) {
                // Guard against underflow if the backend now reports more edges than when the
                // in-degrees were computed; a node already at zero is never re-queued.
                if *deg > 0 {
                    *deg -= 1;
                    if *deg == 0 {
                        self.queue.push_back(neighbor);
                    }
                }
            }
        }
        // Count only after the node is known to be yielded successfully.
        self.count += 1;
        Some(Ok(node))
    }
}

/// Creates a lazy topological-sort iterator over every entity in `graph`.
///
/// # Errors
/// Propagates any backend error raised while computing in-degrees.
pub fn topological_sort_iter<'a>(
    graph: &'a dyn GraphBackend,
) -> Result<TopologicalSortIter<'a>, SqliteGraphError> {
    TopologicalSortIter::new(graph)
}
/// Lazy enumeration of weakly connected components, where edges are followed in both
/// directions.
///
/// Each item is one component as an ascending `Vec` of node ids. Seeds are drawn from the
/// entity ids in ascending order, so components come out ordered by their seed (which is
/// their smallest member whenever every neighbour is itself a listed entity).
pub struct ConnectedComponentsIter<'a> {
    graph: &'a dyn GraphBackend,
    /// Remaining candidate seeds, in ascending order.
    entity_ids: std::vec::IntoIter<i64>,
    /// Nodes already assigned to a yielded (or in-progress) component.
    visited: HashSet<i64>,
    /// Total number of nodes across all components yielded so far.
    count: usize,
}

impl<'a> ConnectedComponentsIter<'a> {
    /// Lists all entity ids from `graph` and prepares to enumerate components lazily.
    ///
    /// The ids are sorted here so that component order does not depend on the order in which
    /// the backend returns them. Duplicate ids are harmless because seeds that have already
    /// been visited are skipped.
    ///
    /// # Errors
    /// Propagates any backend error from `all_entity_ids`.
    pub fn new(graph: &'a dyn GraphBackend) -> Result<Self, SqliteGraphError> {
        let mut all_ids = graph.all_entity_ids()?;
        all_ids.sort_unstable();
        Ok(Self {
            graph,
            entity_ids: all_ids.into_iter(),
            visited: HashSet::new(),
            count: 0,
        })
    }

    /// Total number of nodes contained in the components yielded so far.
    pub fn visited_count(&self) -> usize {
        self.count
    }
}
impl Iterator for ConnectedComponentsIter<'_> {
    type Item = Result<Vec<i64>, SqliteGraphError>;

    /// Yields the next component: every node reachable from the next unvisited seed by
    /// following edges in either direction, sorted ascending.
    ///
    /// Returns `None` once every entity id has been consumed. If a neighbour fetch fails, the
    /// error is yielded and the partially explored component is discarded.
    fn next(&mut self) -> Option<Self::Item> {
        // Skip ids absorbed into earlier components; the first fresh one seeds this component.
        let visited = &mut self.visited;
        let seed = self.entity_ids.find(|&id| visited.insert(id))?;

        let mut members = Vec::new();
        let mut frontier: VecDeque<i64> = std::iter::once(seed).collect();
        while let Some(current) = frontier.pop_front() {
            members.push(current);
            // `insert` returns true only on first sighting, so each node is enqueued once.
            // Outgoing edges are explored first, then incoming ones.
            let outgoing = match self.graph.fetch_outgoing(current) {
                Ok(ids) => ids,
                Err(err) => return Some(Err(err)),
            };
            frontier.extend(outgoing.into_iter().filter(|&id| self.visited.insert(id)));
            let incoming = match self.graph.fetch_incoming(current) {
                Ok(ids) => ids,
                Err(err) => return Some(Err(err)),
            };
            frontier.extend(incoming.into_iter().filter(|&id| self.visited.insert(id)));
        }

        members.sort_unstable();
        self.count += members.len();
        Some(Ok(members))
    }
}

/// Creates a lazy iterator over the weakly connected components of `graph`.
///
/// # Errors
/// Propagates any backend error raised while listing entity ids.
pub fn connected_components_iter<'a>(
    graph: &'a dyn GraphBackend,
) -> Result<ConnectedComponentsIter<'a>, SqliteGraphError> {
    ConnectedComponentsIter::new(graph)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algo::backend::graph_ops;
    use crate::backend::native::v3::V3Backend;
    use crate::backend::{EdgeSpec, GraphBackend, NodeSpec};
    use tempfile::TempDir;

    /// Creates an empty backend in a fresh temporary directory.
    ///
    /// The returned `TempDir` must be kept alive for as long as the backend is used.
    fn create_backend() -> (V3Backend, TempDir) {
        let temp_dir = TempDir::new().expect("invariant: temp dir creation succeeds");
        let db_path = temp_dir.path().join("test.graph");
        let backend = V3Backend::create(&db_path).expect("invariant: backend creation succeeds");
        (backend, temp_dir)
    }

    /// Inserts a node with the given kind and name and no payload, returning its id.
    fn add_node(backend: &V3Backend, kind: &str, name: &str) -> i64 {
        backend
            .insert_node(NodeSpec {
                kind: kind.to_string(),
                name: name.to_string(),
                file_path: None,
                data: serde_json::Value::Null,
            })
            .unwrap()
    }

    /// Inserts a payload-free directed edge `from -> to` of type `"e"`.
    fn add_edge(backend: &V3Backend, from: i64, to: i64) {
        backend
            .insert_edge(EdgeSpec {
                from,
                to,
                edge_type: "e".to_string(),
                data: serde_json::Value::Null,
            })
            .unwrap();
    }

    /// Small DAG shared by the traversal tests:
    ///
    /// ```text
    /// n0 -> n1 -> n2
    /// n0 -> n3
    /// n1 -> n3
    /// ```
    // Not every test reads every node id, hence the allow.
    #[allow(dead_code)]
    struct TestGraph {
        backend: V3Backend,
        _temp: TempDir,
        n0: i64,
        n1: i64,
        n2: i64,
        n3: i64,
    }

    fn make_test_graph() -> TestGraph {
        let (backend, temp) = create_backend();
        let n0 = add_node(&backend, "A", "a");
        let n1 = add_node(&backend, "B", "b");
        let n2 = add_node(&backend, "C", "c");
        let n3 = add_node(&backend, "D", "d");
        add_edge(&backend, n0, n1);
        add_edge(&backend, n1, n2);
        add_edge(&backend, n0, n3);
        add_edge(&backend, n1, n3);
        TestGraph {
            backend,
            _temp: temp,
            n0,
            n1,
            n2,
            n3,
        }
    }

    #[test]
    fn test_bfs_iter_matches_bfs() {
        let tg = make_test_graph();
        let vec_result = graph_ops::bfs(&tg.backend, tg.n1, 3).unwrap();
        let iter_result: Vec<i64> = bfs_iter(&tg.backend, tg.n1, 3)
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        let vec_set: std::collections::HashSet<i64> = vec_result.into_iter().collect();
        let iter_set: std::collections::HashSet<i64> = iter_result.into_iter().collect();
        assert_eq!(
            vec_set, iter_set,
            "BFS iterator should visit same nodes as materialized BFS"
        );
    }

    #[test]
    fn test_bfs_iter_early_termination() {
        let tg = make_test_graph();
        let taken: Vec<i64> = bfs_iter(&tg.backend, tg.n0, 3)
            .take(2)
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert!(
            taken.len() <= 2,
            "Early termination should limit results to at most 2"
        );
    }

    #[test]
    fn test_dfs_iter_matches_dfs() {
        let tg = make_test_graph();
        let vec_result = super::super::traversal::dfs_traversal(&tg.backend, tg.n1).unwrap();
        let iter_result: Vec<i64> = dfs_iter(&tg.backend, tg.n1)
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        let vec_set: std::collections::HashSet<i64> = vec_result.into_iter().collect();
        let iter_set: std::collections::HashSet<i64> = iter_result.into_iter().collect();
        assert_eq!(vec_set, iter_set, "DFS iterator should visit same nodes");
    }

    #[test]
    fn test_topo_iter_matches_topo() {
        let tg = make_test_graph();
        let vec_result = graph_ops::topological_sort(&tg.backend).unwrap();
        let iter_result: Vec<i64> = topological_sort_iter(&tg.backend)
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        let vec_set: std::collections::HashSet<i64> = vec_result.into_iter().collect();
        let iter_set: std::collections::HashSet<i64> = iter_result.into_iter().collect();
        assert_eq!(
            vec_set, iter_set,
            "Topo iterator should visit same nodes as materialized topological_sort"
        );
    }

    #[test]
    fn test_topo_iter_orders_dag_and_reports_no_cycle() {
        let tg = make_test_graph();
        let mut iter = topological_sort_iter(&tg.backend).unwrap();
        let order: Vec<i64> = (&mut iter).collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(order.len(), 4, "Every node of the DAG should be yielded");
        assert!(!iter.has_cycle(), "A DAG must not report a cycle");
        let pos = |id: i64| {
            order
                .iter()
                .position(|&x| x == id)
                .expect("node should be yielded")
        };
        for &(from, to) in &[(tg.n0, tg.n1), (tg.n1, tg.n2), (tg.n0, tg.n3), (tg.n1, tg.n3)] {
            assert!(pos(from) < pos(to), "edge {} -> {} violated", from, to);
        }
    }

    #[test]
    fn test_topo_iter_empty_graph() {
        let (graph, _temp) = create_backend();
        let iter = topological_sort_iter(&graph).unwrap();
        let result: Vec<i64> = iter.collect::<Result<Vec<_>, _>>().unwrap();
        assert!(result.is_empty(), "Empty graph should yield no nodes");
    }

    #[test]
    fn test_bfs_iter_empty_neighborhood() {
        let (graph, _temp) = create_backend();
        let n0 = add_node(&graph, "Node", "a");
        let result: Vec<i64> = bfs_iter(&graph, n0, 3)
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert!(result.is_empty(), "Isolated node should yield no neighbors");
    }

    #[test]
    fn test_dfs_iter_visited_count() {
        let tg = make_test_graph();
        let mut iter = dfs_iter(&tg.backend, tg.n1);
        let collected: Vec<i64> = (&mut iter).collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(collected.len(), 2, "n1 reaches exactly n2 and n3");
        assert_eq!(
            iter.visited_count(),
            collected.len(),
            "visited_count should match number of yielded nodes"
        );
    }

    /// Three weakly connected components: a chain `a0 -> a1 -> a2`, a 2-cycle
    /// `b0 <-> b1`, and the isolated node `c0`.
    fn make_multi_component_graph() -> (V3Backend, TempDir, [i64; 3], [i64; 2], i64) {
        let (backend, temp) = create_backend();
        let a0 = add_node(&backend, "A", "a0");
        let a1 = add_node(&backend, "A", "a1");
        let a2 = add_node(&backend, "A", "a2");
        let b0 = add_node(&backend, "B", "b0");
        let b1 = add_node(&backend, "B", "b1");
        let c0 = add_node(&backend, "C", "c0");
        add_edge(&backend, a0, a1);
        add_edge(&backend, a1, a2);
        add_edge(&backend, b0, b1);
        add_edge(&backend, b1, b0);
        (backend, temp, [a0, a1, a2], [b0, b1], c0)
    }

    #[test]
    fn test_topo_iter_detects_cycle() {
        let (graph, _temp, comp_a, comp_b, comp_c) = make_multi_component_graph();
        let mut iter = topological_sort_iter(&graph).unwrap();
        let order: Vec<i64> = (&mut iter).collect::<Result<Vec<_>, _>>().unwrap();
        assert!(iter.has_cycle(), "b0 <-> b1 forms a cycle");
        assert!(
            !order.contains(&comp_b[0]) && !order.contains(&comp_b[1]),
            "Cycle members must never be yielded"
        );
        assert!(
            comp_a.iter().all(|n| order.contains(n)) && order.contains(&comp_c),
            "Acyclic nodes must all be yielded"
        );
    }

    #[test]
    fn test_cc_iter_matches_vec_result() {
        let (graph, _temp, comp_a, comp_b, comp_c) = make_multi_component_graph();
        let iter_components: Vec<Vec<i64>> = connected_components_iter(&graph)
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        let mut expected = vec![
            {
                let mut c = vec![comp_a[0], comp_a[1], comp_a[2]];
                c.sort();
                c
            },
            {
                let mut c = vec![comp_b[0], comp_b[1]];
                c.sort();
                c
            },
            vec![comp_c],
        ];
        expected.sort_by(|a, b| a[0].cmp(&b[0]));
        assert_eq!(
            iter_components, expected,
            "Iterator components must match expected decomposition"
        );
    }

    #[test]
    fn test_cc_iter_yields_sorted_components() {
        let (graph, _temp, _comp_a, _comp_b, _comp_c) = make_multi_component_graph();
        for comp in connected_components_iter(&graph).unwrap() {
            let comp = comp.expect("component fetch should not fail");
            let mut sorted = comp.clone();
            sorted.sort();
            assert_eq!(comp, sorted, "Each component must be sorted");
        }
    }

    #[test]
    fn test_cc_iter_components_sorted_by_first_element() {
        let (graph, _temp, _comp_a, _comp_b, _comp_c) = make_multi_component_graph();
        let components: Vec<Vec<i64>> = connected_components_iter(&graph)
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        for window in components.windows(2) {
            assert!(
                window[0][0] <= window[1][0],
                "Components must be sorted by first element: {:?} vs {:?}",
                window[0],
                window[1],
            );
        }
    }

    #[test]
    fn test_cc_iter_empty_graph() {
        let (graph, _temp) = create_backend();
        let components: Vec<Vec<i64>> = connected_components_iter(&graph)
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert!(
            components.is_empty(),
            "Empty graph should yield no components"
        );
    }

    #[test]
    fn test_cc_iter_single_isolated_node() {
        let (graph, _temp) = create_backend();
        let n0 = add_node(&graph, "N", "solo");
        let components: Vec<Vec<i64>> = connected_components_iter(&graph)
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(components.len(), 1, "Single node should be one component");
        assert_eq!(
            components[0],
            vec![n0],
            "Component should contain only that node"
        );
    }

    #[test]
    fn test_cc_iter_early_termination() {
        let (graph, _temp, _comp_a, _comp_b, _comp_c) = make_multi_component_graph();
        let first = connected_components_iter(&graph)
            .unwrap()
            .next()
            .expect("should yield at least one component")
            .expect("should not error");
        assert!(!first.is_empty(), "First component should not be empty");
    }

    #[test]
    fn test_cc_iter_visited_count() {
        let (graph, _temp, _comp_a, _comp_b, _comp_c) = make_multi_component_graph();
        let mut iter = connected_components_iter(&graph).unwrap();
        let total: usize = (&mut iter)
            .collect::<Result<Vec<Vec<i64>>, _>>()
            .unwrap()
            .iter()
            .map(|c| c.len())
            .sum();
        assert_eq!(
            iter.visited_count(),
            total,
            "visited_count should equal total nodes across all components"
        );
        assert_eq!(iter.visited_count(), 6, "6 nodes total: 3 + 2 + 1");
    }

    #[test]
    fn test_cc_iter_bidirectional_connectivity() {
        let (graph, _temp) = create_backend();
        let a = add_node(&graph, "X", "a");
        let b = add_node(&graph, "X", "b");
        let c = add_node(&graph, "X", "c");
        add_edge(&graph, a, b);
        let components: Vec<Vec<i64>> = connected_components_iter(&graph)
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(components.len(), 2, "Should have 2 components");
        let ab_component = components
            .iter()
            .find(|c| c.contains(&a))
            .expect("a should be in a component");
        assert!(
            ab_component.contains(&b),
            "a and b should be in the same component (bidirectional)"
        );
        let c_component = components
            .iter()
            .find(|comp| comp.contains(&c))
            .expect("c should be in a component");
        assert_eq!(c_component.len(), 1, "c should be isolated");
    }
}