use std::collections::VecDeque;
use ahash::AHashSet;
use crate::progress::ProgressCallback;
use crate::{errors::SqliteGraphError, graph::SqliteGraph};
/// Collects every node reachable from `start` by following outgoing edges.
///
/// The walk is breadth-first. The returned set always contains `start`
/// itself, even when `start` has no outgoing edges or no matching entity.
///
/// # Errors
///
/// Returns the first error produced by `SqliteGraph::fetch_outgoing`.
pub fn reachable_from(graph: &SqliteGraph, start: i64) -> Result<AHashSet<i64>, SqliteGraphError> {
    let mut seen = AHashSet::new();
    seen.insert(start);
    let mut frontier = VecDeque::from([start]);

    while let Some(current) = frontier.pop_front() {
        let successors = graph.fetch_outgoing(current)?;
        // Only nodes that were newly added to `seen` are scheduled for expansion.
        frontier.extend(successors.into_iter().filter(|&next| seen.insert(next)));
    }

    Ok(seen)
}
/// Forward reachability (same result as `reachable_from`) with progress reporting.
///
/// Every 10th node taken off the BFS queue triggers `progress.on_progress`
/// with the running count and an unknown total (`None`). `progress.on_complete`
/// is called once the traversal finishes successfully.
///
/// # Errors
///
/// Returns the first error produced by `SqliteGraph::fetch_outgoing`. In that
/// case `on_complete` is not called.
pub fn reachable_from_with_progress<F>(
    graph: &SqliteGraph,
    start: i64,
    progress: &F,
) -> Result<AHashSet<i64>, SqliteGraphError>
where
    F: ProgressCallback,
{
    let mut frontier = VecDeque::from([start]);
    let mut seen = AHashSet::new();
    seen.insert(start);
    let mut dequeued = 0;

    while let Some(current) = frontier.pop_front() {
        dequeued += 1;
        // Throttle callbacks: only report on every 10th dequeued node.
        if dequeued % 10 == 0 {
            let status = format!("Forward reachability: visited {}", dequeued);
            progress.on_progress(dequeued, None, &status);
        }
        let successors = graph.fetch_outgoing(current)?;
        frontier.extend(successors.into_iter().filter(|&next| seen.insert(next)));
    }

    progress.on_complete();
    Ok(seen)
}
/// Collects every node that can reach `target`, by walking incoming edges backwards.
///
/// The walk is breadth-first. The returned set always contains `target`
/// itself, even when it has no incoming edges or no matching entity.
///
/// # Errors
///
/// Returns the first error produced by `SqliteGraph::fetch_incoming`.
pub fn reverse_reachable_from(
    graph: &SqliteGraph,
    target: i64,
) -> Result<AHashSet<i64>, SqliteGraphError> {
    let mut pending = VecDeque::from([target]);
    let mut ancestors = AHashSet::new();
    ancestors.insert(target);

    while let Some(current) = pending.pop_front() {
        for predecessor in graph.fetch_incoming(current)? {
            // `insert` returns false for already-seen nodes; skip those.
            if !ancestors.insert(predecessor) {
                continue;
            }
            pending.push_back(predecessor);
        }
    }

    Ok(ancestors)
}
/// Backward reachability (same result as `reverse_reachable_from`) with progress reporting.
///
/// Every 10th node taken off the BFS queue triggers `progress.on_progress`
/// with the running count and an unknown total (`None`). `progress.on_complete`
/// is called once the traversal finishes successfully.
///
/// # Errors
///
/// Returns the first error produced by `SqliteGraph::fetch_incoming`. In that
/// case `on_complete` is not called.
pub fn reverse_reachable_from_with_progress<F>(
    graph: &SqliteGraph,
    target: i64,
    progress: &F,
) -> Result<AHashSet<i64>, SqliteGraphError>
where
    F: ProgressCallback,
{
    let mut pending = VecDeque::from([target]);
    let mut ancestors = AHashSet::new();
    ancestors.insert(target);
    let mut dequeued = 0;

    while let Some(current) = pending.pop_front() {
        dequeued += 1;
        // Throttle callbacks: only report on every 10th dequeued node.
        if dequeued % 10 == 0 {
            let status = format!("Backward reachability: visited {}", dequeued);
            progress.on_progress(dequeued, None, &status);
        }
        let predecessors = graph.fetch_incoming(current)?;
        pending.extend(
            predecessors
                .into_iter()
                .filter(|&prev| ancestors.insert(prev)),
        );
    }

    progress.on_complete();
    Ok(ancestors)
}
/// Reports whether a directed path leads from `from` to `to`.
///
/// A node always reaches itself, so `from == to` returns `true` without
/// touching the graph. Otherwise a breadth-first search over outgoing edges
/// stops as soon as `to` appears among the successors of an expanded node.
///
/// # Errors
///
/// Returns the first error produced by `SqliteGraph::fetch_outgoing`.
pub fn can_reach(graph: &SqliteGraph, from: i64, to: i64) -> Result<bool, SqliteGraphError> {
    if from == to {
        return Ok(true);
    }

    let mut discovered = AHashSet::new();
    discovered.insert(from);
    let mut pending = VecDeque::from([from]);

    while let Some(current) = pending.pop_front() {
        for next in graph.fetch_outgoing(current)? {
            if next == to {
                return Ok(true);
            } else if discovered.insert(next) {
                pending.push_back(next);
            }
        }
    }

    Ok(false)
}
/// Returns every entity ID in the graph that is *not* reachable from `entry`.
///
/// Reachability follows outgoing edges, as in `reachable_from`. If `entry` is
/// not one of the graph's entity IDs, no traversal is performed and every
/// entity is reported as unreachable.
///
/// # Errors
///
/// Propagates errors from `SqliteGraph::all_entity_ids` and from the
/// forward traversal.
pub fn unreachable_from(
    graph: &SqliteGraph,
    entry: i64,
) -> Result<AHashSet<i64>, SqliteGraphError> {
    let mut unreachable: AHashSet<i64> = graph.all_entity_ids()?.into_iter().collect();
    if !unreachable.contains(&entry) {
        return Ok(unreachable);
    }
    let reachable = reachable_from(graph, entry)?;
    // Filter the already-owned set in place instead of allocating a third set
    // through `difference(..).copied().collect()`.
    unreachable.retain(|id| !reachable.contains(id));
    Ok(unreachable)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{GraphEdge, GraphEntity};

    /// Builds an in-memory graph with `node_count` entities and the given edges.
    ///
    /// Edges are `(from_index, to_index)` pairs that index into the ID list
    /// returned by `list_entity_ids()`. This lets each fixture be described by
    /// node position rather than by database-assigned ID.
    fn build_graph(node_count: usize, edges: &[(usize, usize)]) -> SqliteGraph {
        let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
        for i in 0..node_count {
            let entity = GraphEntity {
                id: 0,
                kind: "node".to_string(),
                name: format!("node_{}", i),
                file_path: Some(format!("node_{}.rs", i)),
                data: serde_json::json!({"index": i}),
            };
            graph
                .insert_entity(&entity)
                .expect("Failed to insert entity");
        }
        let entity_ids = ids_of(&graph);
        for &(from_idx, to_idx) in edges {
            let edge = GraphEdge {
                id: 0,
                from_id: entity_ids[from_idx],
                to_id: entity_ids[to_idx],
                edge_type: "next".to_string(),
                data: serde_json::json!({}),
            };
            graph.insert_edge(&edge).expect("Failed to insert edge");
        }
        graph
    }

    /// Returns the graph's entity IDs in the order `list_entity_ids()` yields them.
    fn ids_of(graph: &SqliteGraph) -> Vec<i64> {
        graph.list_entity_ids().expect("Failed to get IDs")
    }

    /// 0 -> 1 -> 2 -> 3
    fn create_linear_chain() -> SqliteGraph {
        build_graph(4, &[(0, 1), (1, 2), (2, 3)])
    }

    /// 0 -> 1, 0 -> 2, 1 -> 3, 2 -> 3
    fn create_diamond() -> SqliteGraph {
        build_graph(4, &[(0, 1), (0, 2), (1, 3), (2, 3)])
    }

    /// 0 -> 1, 1 -> 2, 2 -> 1 (node 0 feeds into a two-node cycle)
    fn create_cycle() -> SqliteGraph {
        build_graph(3, &[(0, 1), (1, 2), (2, 1)])
    }

    /// Two separate components: 0 -> 1 and 2 -> 3.
    fn create_disconnected() -> SqliteGraph {
        build_graph(4, &[(0, 1), (2, 3)])
    }

    #[test]
    fn test_reachable_from_empty() {
        let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
        let result = reachable_from(&graph, 999);
        assert!(result.is_ok(), "reachable_from failed on empty graph");
        let reachable = result.unwrap();
        assert_eq!(
            reachable.len(),
            1,
            "Expected only start node in empty graph"
        );
        assert!(reachable.contains(&999), "Start node should be in result");
    }

    #[test]
    fn test_reachable_from_single() {
        let graph = build_graph(1, &[]);
        let node_id = ids_of(&graph)[0];
        let result = reachable_from(&graph, node_id);
        assert!(result.is_ok(), "reachable_from failed on single node");
        let reachable = result.unwrap();
        assert_eq!(reachable.len(), 1, "Expected 1 node reachable");
        assert!(reachable.contains(&node_id), "Node should reach itself");
    }

    #[test]
    fn test_reachable_from_linear() {
        let graph = create_linear_chain();
        let entity_ids = ids_of(&graph);
        let reachable_0 = reachable_from(&graph, entity_ids[0]).expect("Failed");
        assert_eq!(
            reachable_0.len(),
            4,
            "Node 0 should reach all 4 nodes in chain"
        );
        for &id in &entity_ids {
            assert!(reachable_0.contains(&id), "Node 0 should reach node {}", id);
        }
        let reachable_3 = reachable_from(&graph, entity_ids[3]).expect("Failed");
        assert_eq!(reachable_3.len(), 1, "Node 3 should reach only itself");
        assert!(
            reachable_3.contains(&entity_ids[3]),
            "Node 3 should reach itself"
        );
    }

    #[test]
    fn test_reachable_from_diamond() {
        let graph = create_diamond();
        let entity_ids = ids_of(&graph);
        let reachable_0 = reachable_from(&graph, entity_ids[0]).expect("Failed");
        assert_eq!(
            reachable_0.len(),
            4,
            "Node 0 should reach all 4 nodes in diamond"
        );
        for &id in &entity_ids {
            assert!(reachable_0.contains(&id), "Node 0 should reach node {}", id);
        }
    }

    #[test]
    fn test_reachable_from_cycle() {
        let graph = create_cycle();
        let entity_ids = ids_of(&graph);
        let node_0 = entity_ids[0];
        let node_1 = entity_ids[1];
        let node_2 = entity_ids[2];
        let reachable_0 = reachable_from(&graph, node_0).expect("Failed");
        assert_eq!(reachable_0.len(), 3, "Node 0 should reach all 3 nodes");
        let reachable_1 = reachable_from(&graph, node_1).expect("Failed");
        assert_eq!(
            reachable_1.len(),
            2,
            "Node 1 should reach 2 nodes (1 and 2)"
        );
        assert!(reachable_1.contains(&node_1), "Node 1 should reach itself");
        assert!(reachable_1.contains(&node_2), "Node 1 should reach node 2");
        let reachable_2 = reachable_from(&graph, node_2).expect("Failed");
        assert_eq!(
            reachable_2.len(),
            2,
            "Node 2 should reach 2 nodes (1 and 2)"
        );
        assert!(reachable_2.contains(&node_1), "Node 2 should reach node 1");
        assert!(reachable_2.contains(&node_2), "Node 2 should reach itself");
    }

    #[test]
    fn test_reachable_from_disconnected() {
        let graph = create_disconnected();
        let entity_ids = ids_of(&graph);
        let reachable_0 = reachable_from(&graph, entity_ids[0]).expect("Failed");
        assert_eq!(reachable_0.len(), 2, "Node 0 should reach 2 nodes");
        assert!(
            reachable_0.contains(&entity_ids[0]),
            "Node 0 should reach itself"
        );
        assert!(
            reachable_0.contains(&entity_ids[1]),
            "Node 0 should reach node 1"
        );
        assert!(
            !reachable_0.contains(&entity_ids[2]),
            "Node 0 should NOT reach node 2"
        );
        assert!(
            !reachable_0.contains(&entity_ids[3]),
            "Node 0 should NOT reach node 3"
        );
    }

    #[test]
    fn test_reachable_from_with_progress() {
        use crate::progress::NoProgress;
        let graph = create_linear_chain();
        let entity_ids = ids_of(&graph);
        let progress = NoProgress;
        let result_with =
            reachable_from_with_progress(&graph, entity_ids[0], &progress).expect("Failed");
        let result_without = reachable_from(&graph, entity_ids[0]).expect("Failed");
        assert_eq!(
            result_with.len(),
            result_without.len(),
            "Progress and non-progress results should match"
        );
        for &id in &result_with {
            assert!(
                result_without.contains(&id),
                "Progress result contains node not in non-progress result"
            );
        }
    }

    #[test]
    fn test_reverse_reachable_from_empty() {
        let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
        let result = reverse_reachable_from(&graph, 999);
        assert!(
            result.is_ok(),
            "reverse_reachable_from failed on empty graph"
        );
        let reachable = result.unwrap();
        assert_eq!(
            reachable.len(),
            1,
            "Expected only target node in empty graph"
        );
        assert!(reachable.contains(&999), "Target node should be in result");
    }

    #[test]
    fn test_reverse_reachable_from_single() {
        let graph = build_graph(1, &[]);
        let node_id = ids_of(&graph)[0];
        let result = reverse_reachable_from(&graph, node_id);
        assert!(
            result.is_ok(),
            "reverse_reachable_from failed on single node"
        );
        let reachable = result.unwrap();
        assert_eq!(reachable.len(), 1, "Expected 1 node reachable");
        assert!(reachable.contains(&node_id), "Node should reach itself");
    }

    #[test]
    fn test_reverse_reachable_from_linear() {
        let graph = create_linear_chain();
        let entity_ids = ids_of(&graph);
        let reverse_3 = reverse_reachable_from(&graph, entity_ids[3]).expect("Failed");
        assert_eq!(
            reverse_3.len(),
            4,
            "Node 3 should be reachable from all 4 nodes in chain"
        );
        for &id in &entity_ids {
            assert!(
                reverse_3.contains(&id),
                "Node {} should be able to reach node 3",
                id
            );
        }
        let reverse_0 = reverse_reachable_from(&graph, entity_ids[0]).expect("Failed");
        assert_eq!(reverse_0.len(), 1, "Node 0 should only reach itself");
        assert!(
            reverse_0.contains(&entity_ids[0]),
            "Node 0 should reach itself"
        );
    }

    #[test]
    fn test_reverse_reachable_from_diamond() {
        let graph = create_diamond();
        let entity_ids = ids_of(&graph);
        let reverse_3 = reverse_reachable_from(&graph, entity_ids[3]).expect("Failed");
        assert_eq!(
            reverse_3.len(),
            4,
            "Node 3 should be reachable from all 4 nodes in diamond"
        );
        for &id in &entity_ids {
            assert!(
                reverse_3.contains(&id),
                "Node {} should be able to reach node 3",
                id
            );
        }
    }

    #[test]
    fn test_reverse_reachable_from_cycle() {
        let graph = create_cycle();
        let entity_ids = ids_of(&graph);
        let node_0 = entity_ids[0];
        let node_1 = entity_ids[1];
        let node_2 = entity_ids[2];
        let reverse_1 = reverse_reachable_from(&graph, node_1).expect("Failed");
        assert_eq!(
            reverse_1.len(),
            3,
            "Node 1 should be reachable from all 3 nodes"
        );
        assert!(reverse_1.contains(&node_0), "Node 0 should reach node 1");
        assert!(reverse_1.contains(&node_1), "Node 1 should reach itself");
        assert!(reverse_1.contains(&node_2), "Node 2 should reach node 1");
        let reverse_2 = reverse_reachable_from(&graph, node_2).expect("Failed");
        assert_eq!(
            reverse_2.len(),
            3,
            "Node 2 should be reachable from 3 nodes"
        );
        assert!(reverse_2.contains(&node_0), "Node 0 should reach node 2");
        assert!(reverse_2.contains(&node_1), "Node 1 should reach node 2");
        assert!(reverse_2.contains(&node_2), "Node 2 should reach itself");
    }

    #[test]
    fn test_reverse_reachable_from_with_progress() {
        use crate::progress::NoProgress;
        let graph = create_linear_chain();
        let entity_ids = ids_of(&graph);
        let progress = NoProgress;
        let result_with =
            reverse_reachable_from_with_progress(&graph, entity_ids[3], &progress).expect("Failed");
        let result_without = reverse_reachable_from(&graph, entity_ids[3]).expect("Failed");
        assert_eq!(
            result_with.len(),
            result_without.len(),
            "Progress and non-progress results should match"
        );
        for &id in &result_with {
            assert!(
                result_without.contains(&id),
                "Progress result contains node not in non-progress result"
            );
        }
    }

    #[test]
    fn test_can_reach_self() {
        let graph = create_linear_chain();
        for &node_id in &ids_of(&graph) {
            let result = can_reach(&graph, node_id, node_id).expect("Failed");
            assert!(result, "Node {} should be able to reach itself", node_id);
        }
    }

    #[test]
    fn test_can_reach_linear() {
        let graph = create_linear_chain();
        let entity_ids = ids_of(&graph);
        assert!(
            can_reach(&graph, entity_ids[0], entity_ids[0]).expect("Failed"),
            "Node 0 should reach itself"
        );
        assert!(
            can_reach(&graph, entity_ids[0], entity_ids[1]).expect("Failed"),
            "Node 0 should reach node 1"
        );
        assert!(
            can_reach(&graph, entity_ids[0], entity_ids[2]).expect("Failed"),
            "Node 0 should reach node 2"
        );
        assert!(
            can_reach(&graph, entity_ids[0], entity_ids[3]).expect("Failed"),
            "Node 0 should reach node 3"
        );
        assert!(
            can_reach(&graph, entity_ids[3], entity_ids[3]).expect("Failed"),
            "Node 3 should reach itself"
        );
        assert!(
            !can_reach(&graph, entity_ids[3], entity_ids[0]).expect("Failed"),
            "Node 3 should NOT reach node 0"
        );
    }

    #[test]
    fn test_can_reach_cycle() {
        let graph = create_cycle();
        let entity_ids = ids_of(&graph);
        let node_1 = entity_ids[1];
        let node_2 = entity_ids[2];
        assert!(
            can_reach(&graph, node_1, node_2).expect("Failed"),
            "Node 1 should reach node 2 (in cycle)"
        );
        assert!(
            can_reach(&graph, node_2, node_1).expect("Failed"),
            "Node 2 should reach node 1 (in cycle)"
        );
    }

    #[test]
    fn test_can_reach_disconnected() {
        let graph = create_disconnected();
        let entity_ids = ids_of(&graph);
        assert!(
            !can_reach(&graph, entity_ids[0], entity_ids[2]).expect("Failed"),
            "Node 0 should NOT reach node 2 (disconnected)"
        );
        assert!(
            !can_reach(&graph, entity_ids[0], entity_ids[3]).expect("Failed"),
            "Node 0 should NOT reach node 3 (disconnected)"
        );
    }

    #[test]
    fn test_can_reach_nonexistent() {
        let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
        let result = can_reach(&graph, 999, 888);
        assert!(
            result.is_ok(),
            "can_reach should not error on non-existent nodes"
        );
        assert!(
            !result.unwrap(),
            "Non-existent nodes should not reach each other"
        );
    }

    #[test]
    fn test_unreachable_from_empty() {
        let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
        let result = unreachable_from(&graph, 0);
        assert!(result.is_ok(), "unreachable_from failed on empty graph");
        let unreachable = result.unwrap();
        assert_eq!(
            unreachable.len(),
            0,
            "Expected 0 unreachable nodes in empty graph"
        );
    }

    #[test]
    fn test_unreachable_from_linear() {
        let graph = create_linear_chain();
        let entity_ids = ids_of(&graph);
        let unreachable = unreachable_from(&graph, entity_ids[0]).expect("Failed");
        assert_eq!(
            unreachable.len(),
            0,
            "Expected 0 unreachable nodes in fully connected chain"
        );
    }

    #[test]
    fn test_unreachable_from_disconnected() {
        let graph = create_disconnected();
        let entity_ids = ids_of(&graph);
        let unreachable = unreachable_from(&graph, entity_ids[0]).expect("Failed");
        assert_eq!(unreachable.len(), 2, "Expected 2 unreachable nodes");
        assert!(
            unreachable.contains(&entity_ids[2]),
            "Node 2 should be unreachable from node 0"
        );
        assert!(
            unreachable.contains(&entity_ids[3]),
            "Node 3 should be unreachable from node 0"
        );
        assert!(
            !unreachable.contains(&entity_ids[0]),
            "Node 0 should not be unreachable from itself"
        );
        assert!(
            !unreachable.contains(&entity_ids[1]),
            "Node 1 should be reachable from node 0"
        );
    }

    #[test]
    fn test_unreachable_from_diamond() {
        let graph = create_diamond();
        let entity_ids = ids_of(&graph);
        let unreachable = unreachable_from(&graph, entity_ids[0]).expect("Failed");
        assert_eq!(
            unreachable.len(),
            0,
            "Expected 0 unreachable nodes in diamond (all reachable from 0)"
        );
    }

    #[test]
    fn test_unreachable_from_nonexistent_entry() {
        let graph = create_linear_chain();
        let entity_ids = ids_of(&graph);
        let unreachable = unreachable_from(&graph, 999).expect("Failed");
        assert_eq!(
            unreachable.len(),
            4,
            "Expected all 4 nodes to be unreachable from non-existent entry"
        );
        for &id in &entity_ids {
            assert!(
                unreachable.contains(&id),
                "Node {} should be unreachable from non-existent entry",
                id
            );
        }
    }
}