use std::collections::VecDeque;
use ahash::AHashMap;
use crate::progress::ProgressCallback;
use crate::{errors::SqliteGraphError, graph::SqliteGraph};
/// Optional limits applied while computing a transitive closure.
///
/// Every field defaults to `None`, which means "no limit" for that dimension.
#[derive(Debug, Clone, Default)]
pub struct TransitiveClosureBounds {
    /// Maximum number of edges followed away from each source. Targets that
    /// are farther away are not recorded; `Some(0)` records only the source
    /// itself.
    pub max_depth: Option<usize>,
    /// Maximum number of source entities to expand, taken from the front of
    /// the id list returned by the graph.
    pub max_sources: Option<usize>,
    /// Requested cap on the number of recorded `(source, target)` pairs; the
    /// closure functions in this module use it to stop early.
    pub max_pairs: Option<usize>,
}

impl TransitiveClosureBounds {
    /// Bounds with no limit on depth, sources, or pairs.
    #[inline]
    pub fn unbounded() -> Self {
        Self {
            max_depth: None,
            max_sources: None,
            max_pairs: None,
        }
    }

    /// Bounds that only limit the traversal depth to `max_depth`.
    #[inline]
    pub fn with_depth(max_depth: usize) -> Self {
        Self {
            max_depth: Some(max_depth),
            ..Self::default()
        }
    }

    /// Bounds that only limit the number of expanded sources to `max_sources`.
    #[inline]
    pub fn with_sources(max_sources: usize) -> Self {
        Self {
            max_sources: Some(max_sources),
            ..Self::default()
        }
    }
}
/// Computes the reflexive transitive closure of `graph` with one
/// breadth-first search per source entity.
///
/// Sources are the ids returned by `SqliteGraph::all_entity_ids`, in that
/// order; neighbours come from `SqliteGraph::fetch_outgoing`.
///
/// # Arguments
///
/// * `graph` - graph to traverse.
/// * `bounds` - optional limits; `None` behaves like
///   `TransitiveClosureBounds::default()` (no limits).
///   * `max_depth`: targets more than this many edges from the source are not
///     recorded (`Some(0)` yields only the self-pairs).
///   * `max_sources`: only the first `max_sources` ids are used as sources.
///   * `max_pairs`: the search stops as soon as the result holds this many
///     pairs. Self-pairs count toward the cap, so the result never contains
///     more than `max_pairs` entries (`Some(0)` yields an empty map).
///
/// # Returns
///
/// A map whose keys are the `(source, target)` pairs found to be reachable;
/// every stored value is `true`. Each expanded source reaches itself. Pairs
/// absent from the map were not reached within the given bounds.
///
/// # Errors
///
/// Propagates any error returned by `all_entity_ids` or `fetch_outgoing`.
pub fn transitive_closure(
    graph: &SqliteGraph,
    bounds: Option<TransitiveClosureBounds>,
) -> Result<AHashMap<(i64, i64), bool>, SqliteGraphError> {
    let all_ids = graph.all_entity_ids()?;
    let bounds = bounds.unwrap_or_default();
    let max_depth = bounds.max_depth;
    let max_pairs = bounds.max_pairs;
    let source_limit = bounds.max_sources.unwrap_or(all_ids.len());

    // True once `recorded` pairs satisfy the optional pair cap.
    let pair_cap_hit = |recorded: usize| max_pairs.map_or(false, |cap| recorded >= cap);
    // True while a node at `depth` may still have its neighbours explored.
    let may_expand = |depth: usize| max_depth.map_or(true, |cap| depth < cap);

    let mut closure = AHashMap::new();
    // A cap of zero asks for no pairs at all, not even self-pairs.
    if pair_cap_hit(0) {
        return Ok(closure);
    }

    for source in all_ids.into_iter().take(source_limit) {
        closure.insert((source, source), true);
        // The self-pair counts toward the cap too; without this check the
        // result could overshoot, and edgeless graphs would ignore the cap.
        if pair_cap_hit(closure.len()) {
            return Ok(closure);
        }

        let mut visited = ahash::AHashSet::new();
        visited.insert(source);
        let mut queue = VecDeque::new();
        queue.push_back((source, 0usize));

        while let Some((node, depth)) = queue.pop_front() {
            // Only reachable for the source itself when `max_depth` is 0;
            // deeper nodes are filtered before they are enqueued.
            if !may_expand(depth) {
                continue;
            }
            for &neighbor in &graph.fetch_outgoing(node)? {
                if !visited.insert(neighbor) {
                    continue;
                }
                closure.insert((source, neighbor), true);
                if pair_cap_hit(closure.len()) {
                    return Ok(closure);
                }
                // Skip enqueueing nodes whose neighbours would be too deep.
                if may_expand(depth + 1) {
                    queue.push_back((neighbor, depth + 1));
                }
            }
        }
    }
    Ok(closure)
}
/// Computes the reflexive transitive closure of `graph` like a per-source
/// breadth-first search, reporting progress through `progress`.
///
/// Sources are the ids returned by `SqliteGraph::all_entity_ids`, in that
/// order; neighbours come from `SqliteGraph::fetch_outgoing`.
///
/// # Arguments
///
/// * `graph` - graph to traverse.
/// * `bounds` - optional limits; `None` behaves like
///   `TransitiveClosureBounds::default()` (no limits).
///   * `max_depth`: targets more than this many edges from the source are not
///     recorded (`Some(0)` yields only the self-pairs).
///   * `max_sources`: only the first `max_sources` ids are used as sources.
///   * `max_pairs`: the search stops as soon as the result holds this many
///     pairs. Self-pairs count toward the cap, so the result never contains
///     more than `max_pairs` entries (`Some(0)` yields an empty map).
/// * `progress` - receives one `on_progress` call per source before that
///   source is expanded (1-based position, total number of sources, and a
///   message), and a single `on_complete` call before a successful return.
///
/// # Returns
///
/// A map whose keys are the `(source, target)` pairs found to be reachable;
/// every stored value is `true`. Pairs absent from the map were not reached
/// within the given bounds.
///
/// # Errors
///
/// Propagates any error returned by `all_entity_ids` or `fetch_outgoing`;
/// `on_complete` is not called in that case.
pub fn transitive_closure_with_progress<F>(
    graph: &SqliteGraph,
    bounds: Option<TransitiveClosureBounds>,
    progress: &F,
) -> Result<AHashMap<(i64, i64), bool>, SqliteGraphError>
where
    F: ProgressCallback,
{
    let all_ids = graph.all_entity_ids()?;
    let bounds = bounds.unwrap_or_default();
    let max_depth = bounds.max_depth;
    let max_pairs = bounds.max_pairs;
    let source_limit = bounds.max_sources.unwrap_or(all_ids.len());

    // True once `recorded` pairs satisfy the optional pair cap.
    let pair_cap_hit = |recorded: usize| max_pairs.map_or(false, |cap| recorded >= cap);
    // True while a node at `depth` may still have its neighbours explored.
    let may_expand = |depth: usize| max_depth.map_or(true, |cap| depth < cap);

    let mut closure = AHashMap::new();
    // A cap of zero asks for no pairs at all, not even self-pairs.
    if pair_cap_hit(0) {
        progress.on_complete();
        return Ok(closure);
    }

    let sources: Vec<i64> = all_ids.into_iter().take(source_limit).collect();
    let total = sources.len();

    // Every successful exit leaves this loop and falls through to the single
    // `on_complete` call below.
    'sources: for (source_idx, &source) in sources.iter().enumerate() {
        let position = source_idx + 1;
        progress.on_progress(
            position,
            Some(total),
            &format!("Transitive closure: source {}/{}", position, total),
        );

        closure.insert((source, source), true);
        // The self-pair counts toward the cap too; without this check the
        // result could overshoot, and edgeless graphs would ignore the cap.
        if pair_cap_hit(closure.len()) {
            break 'sources;
        }

        let mut visited = ahash::AHashSet::new();
        visited.insert(source);
        let mut queue = VecDeque::new();
        queue.push_back((source, 0usize));

        while let Some((node, depth)) = queue.pop_front() {
            // Only reachable for the source itself when `max_depth` is 0;
            // deeper nodes are filtered before they are enqueued.
            if !may_expand(depth) {
                continue;
            }
            for &neighbor in &graph.fetch_outgoing(node)? {
                if !visited.insert(neighbor) {
                    continue;
                }
                closure.insert((source, neighbor), true);
                if pair_cap_hit(closure.len()) {
                    break 'sources;
                }
                // Skip enqueueing nodes whose neighbours would be too deep.
                if may_expand(depth + 1) {
                    queue.push_back((neighbor, depth + 1));
                }
            }
        }
    }

    progress.on_complete();
    Ok(closure)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{GraphEdge, GraphEntity};

    /// Opens a fresh in-memory graph, panicking if it cannot be created.
    fn empty_graph() -> SqliteGraph {
        SqliteGraph::open_in_memory().expect("Failed to create graph")
    }

    /// Inserts `count` entities named `node_0`, `node_1`, ... into `graph`.
    fn insert_test_nodes(graph: &SqliteGraph, count: i32) {
        for i in 0..count {
            graph
                .insert_entity(&GraphEntity {
                    id: 0,
                    kind: "test".to_string(),
                    name: format!("node_{}", i),
                    file_path: Some(format!("test_{}.rs", i)),
                    data: serde_json::json!({"index": i}),
                })
                .expect("Failed to insert entity");
        }
    }

    /// Inserts a `connects` edge from `from_id` to `to_id`.
    fn connect(graph: &SqliteGraph, from_id: i64, to_id: i64) {
        graph
            .insert_edge(&GraphEdge {
                id: 0,
                from_id,
                to_id,
                edge_type: "connects".to_string(),
                data: serde_json::json!({}),
            })
            .expect("Failed to insert edge");
    }

    /// Returns the entity ids of `graph` as listed by `list_entity_ids`.
    fn entity_ids_of(graph: &SqliteGraph) -> Vec<i64> {
        graph.list_entity_ids().expect("Failed to get IDs")
    }

    /// Four nodes chained as 0 -> 1 -> 2 -> 3.
    fn create_linear_graph() -> SqliteGraph {
        let graph = empty_graph();
        insert_test_nodes(&graph, 4);
        let ids = entity_ids_of(&graph);
        for pair in ids.windows(2) {
            connect(&graph, pair[0], pair[1]);
        }
        graph
    }

    /// Three nodes with edges 0 -> 1, 1 -> 2 and 2 -> 1 (a cycle between 1 and 2).
    fn create_cycle_graph() -> SqliteGraph {
        let graph = empty_graph();
        insert_test_nodes(&graph, 3);
        let ids = entity_ids_of(&graph);
        for &(from_idx, to_idx) in &[(0usize, 1usize), (1, 2), (2, 1)] {
            connect(&graph, ids[from_idx], ids[to_idx]);
        }
        graph
    }

    /// Runs `transitive_closure` and unwraps the result, failing the test on error.
    fn closure_of(
        graph: &SqliteGraph,
        bounds: Option<TransitiveClosureBounds>,
    ) -> AHashMap<(i64, i64), bool> {
        let result = transitive_closure(graph, bounds);
        assert!(result.is_ok(), "transitive_closure failed");
        result.unwrap()
    }

    /// Checks each `(from, to, expected lookup, failure message)` row against `closure`.
    fn assert_pairs(
        closure: &AHashMap<(i64, i64), bool>,
        expectations: &[(i64, i64, Option<&bool>, &str)],
    ) {
        for &(from, to, expected, message) in expectations {
            assert_eq!(closure.get(&(from, to)), expected, "{}", message);
        }
    }

    /// Checks that every id in `ids` is recorded as reaching itself.
    fn assert_all_self_reachable(closure: &AHashMap<(i64, i64), bool>, ids: &[i64]) {
        for &node_id in ids {
            assert_eq!(
                closure.get(&(node_id, node_id)),
                Some(&true),
                "Node {} should reach itself",
                node_id
            );
        }
    }

    #[test]
    fn test_transitive_closure_empty() {
        let graph = empty_graph();
        let closure = closure_of(&graph, None);
        assert_eq!(closure.len(), 0, "Expected empty closure for empty graph");
    }

    #[test]
    fn test_transitive_closure_single_node() {
        let graph = empty_graph();
        graph
            .insert_entity(&GraphEntity {
                id: 0,
                kind: "test".to_string(),
                name: "single_node".to_string(),
                file_path: Some("test.rs".to_string()),
                data: serde_json::json!({}),
            })
            .expect("Failed to insert entity");
        let node_id = entity_ids_of(&graph)[0];

        let closure = closure_of(&graph, None);
        assert_eq!(closure.len(), 1, "Expected 1 reachable pair");
        assert_eq!(
            closure.get(&(node_id, node_id)),
            Some(&true),
            "Node should reach itself"
        );
    }

    #[test]
    fn test_transitive_closure_linear_chain() {
        let graph = create_linear_graph();
        let ids = entity_ids_of(&graph);
        let closure = closure_of(&graph, None);

        // In a chain, node i reaches node j exactly when i comes first or i == j.
        for (i, &from) in ids.iter().enumerate() {
            for (j, &to) in ids.iter().enumerate() {
                let expected = i <= j;
                let verb = if expected {
                    "be able to"
                } else {
                    "NOT be able to"
                };
                let can_reach = closure.get(&(from, to)).copied().unwrap_or(false);
                assert_eq!(
                    can_reach, expected,
                    "Node {} ({}) should {} reach node {} ({})",
                    i, from, verb, j, to
                );
            }
        }
        assert_all_self_reachable(&closure, &ids);
    }

    #[test]
    fn test_transitive_closure_cycle() {
        let graph = create_cycle_graph();
        let ids = entity_ids_of(&graph);
        let (node_0, node_1, node_2) = (ids[0], ids[1], ids[2]);
        let closure = closure_of(&graph, None);

        assert_pairs(
            &closure,
            &[
                (node_0, node_0, Some(&true), "Node 0 should reach itself"),
                (node_0, node_1, Some(&true), "Node 0 should reach node 1"),
                (node_0, node_2, Some(&true), "Node 0 should reach node 2"),
                (node_1, node_1, Some(&true), "Node 1 should reach itself"),
                (node_1, node_2, Some(&true), "Node 1 should reach node 2"),
                (node_2, node_1, Some(&true), "Node 2 should reach node 1"),
                (node_2, node_2, Some(&true), "Node 2 should reach itself"),
                (node_1, node_0, None, "Node 1 should NOT reach node 0"),
                (node_2, node_0, None, "Node 2 should NOT reach node 0"),
            ],
        );
    }

    #[test]
    fn test_transitive_closure_bounded_depth() {
        let graph = create_linear_graph();
        let ids = entity_ids_of(&graph);
        let (node_0, node_1, node_2, node_3) = (ids[0], ids[1], ids[2], ids[3]);
        let closure = closure_of(&graph, Some(TransitiveClosureBounds::with_depth(2)));

        assert_pairs(
            &closure,
            &[
                (node_0, node_0, Some(&true), "Node 0 should reach itself"),
                (node_0, node_1, Some(&true), "Node 0 should reach node 1"),
                (node_0, node_2, Some(&true), "Node 0 should reach node 2"),
                (
                    node_0,
                    node_3,
                    None,
                    "Node 0 should NOT reach node 3 (depth limit)",
                ),
            ],
        );
    }

    #[test]
    fn test_transitive_closure_bounded_pairs() {
        let graph = create_linear_graph();
        let bounds = TransitiveClosureBounds {
            max_pairs: Some(5),
            ..TransitiveClosureBounds::default()
        };
        let closure = closure_of(&graph, Some(bounds));
        assert!(
            closure.len() <= 6,
            "Should stop at approximately 5 pairs, got {}",
            closure.len()
        );
    }

    #[test]
    fn test_transitive_closure_bounded_sources() {
        let graph = create_linear_graph();
        let ids = entity_ids_of(&graph);
        let closure = closure_of(&graph, Some(TransitiveClosureBounds::with_sources(2)));

        let has_source = |id: i64| closure.keys().any(|&(src, _)| src == id);
        assert!(
            has_source(ids[0]),
            "Source 0 should have reachability entries"
        );
        assert!(
            has_source(ids[1]),
            "Source 1 should have reachability entries"
        );
        assert!(
            !has_source(ids[2]),
            "Source 2 should NOT have reachability entries (source limit)"
        );
    }

    #[test]
    fn test_transitive_closure_bounds_default() {
        let graph = create_linear_graph();
        let result_none = transitive_closure(&graph, None);
        let result_default = transitive_closure(&graph, Some(TransitiveClosureBounds::default()));
        assert!(result_none.is_ok(), "transitive_closure with None failed");
        assert!(
            result_default.is_ok(),
            "transitive_closure with default failed"
        );
        let pairs_none = result_none.unwrap().len();
        let pairs_default = result_default.unwrap().len();
        assert_eq!(
            pairs_none, pairs_default,
            "Default bounds should match None"
        );
    }

    #[test]
    fn test_transitive_closure_with_progress() {
        use crate::progress::NoProgress;

        let graph = create_linear_graph();
        let result = transitive_closure_with_progress(&graph, None, &NoProgress);
        assert!(result.is_ok(), "transitive_closure_with_progress failed");
        let closure = result.unwrap();
        assert!(!closure.is_empty(), "Should have reachable pairs");
    }

    #[test]
    fn test_transitive_closure_self_reachability() {
        let graph = create_cycle_graph();
        let ids = entity_ids_of(&graph);
        let closure = closure_of(&graph, None);
        assert_all_self_reachable(&closure, &ids);
    }
}