use ahash::{AHashMap, AHashSet};
use crate::errors::SqliteGraphError;
use crate::graph::SqliteGraph;
use crate::progress::ProgressCallback;
use super::scc::strongly_connected_components;
/// Result of condensing a graph by collapsing every strongly connected component (SCC)
/// into a single "supernode".
///
/// As produced by [`collapse_sccs`], supernode ids are the SCC component indices cast to
/// `i64`; they are not entity ids.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SccCollapseResult {
    /// Maps each original node id to the id of the supernode that contains it.
    pub node_to_supernode: AHashMap<i64, i64>,
    /// Maps each supernode id to the set of original node ids collapsed into it.
    pub supernode_members: AHashMap<i64, AHashSet<i64>>,
    /// Edges between distinct supernodes. As produced by [`collapse_sccs`] they are
    /// sorted, contain no duplicates, and contain no self-loops.
    pub supernode_edges: Vec<(i64, i64)>,
    /// Number of SCCs reported by the SCC pass.
    pub num_sccs: usize,
}
impl SccCollapseResult {
pub fn supernode_for(&self, node: i64) -> Option<i64> {
self.node_to_supernode.get(&node).copied()
}
pub fn members_of(&self, supernode: i64) -> Option<&AHashSet<i64>> {
self.supernode_members.get(&supernode)
}
pub fn is_trivial(&self, supernode: i64) -> bool {
match self.members_of(supernode) {
Some(members) => members.len() == 1,
None => false,
}
}
pub fn non_trivial_count(&self) -> usize {
self.supernode_members
.values()
.filter(|members| members.len() > 1)
.count()
}
pub fn non_trivial_nodes(&self) -> AHashSet<i64> {
self.supernode_members
.values()
.filter(|members| members.len() > 1)
.flat_map(|members| members.iter().copied())
.collect()
}
}
/// Condenses `graph` by collapsing each strongly connected component into one supernode.
///
/// Supernode ids are the component indices reported by [`strongly_connected_components`],
/// cast to `i64`. An edge is kept only when both endpoints are present in the component
/// mapping and lie in different components. The resulting `supernode_edges` therefore
/// contain no self-loops; they are also sorted and duplicate-free.
///
/// # Errors
///
/// Propagates any [`SqliteGraphError`] returned while computing the SCCs, listing entity
/// ids, or fetching outgoing neighbours.
pub fn collapse_sccs(graph: &SqliteGraph) -> Result<SccCollapseResult, SqliteGraphError> {
    let sccs = strongly_connected_components(graph)?;
    let num_sccs = sccs.components.len();

    // No components: return an empty condensation without scanning any edges.
    if num_sccs == 0 {
        return Ok(SccCollapseResult {
            node_to_supernode: AHashMap::new(),
            supernode_members: AHashMap::new(),
            supernode_edges: Vec::new(),
            num_sccs: 0,
        });
    }

    let mut node_to_supernode: AHashMap<i64, i64> = AHashMap::new();
    let mut supernode_members: AHashMap<i64, AHashSet<i64>> = AHashMap::new();
    for (&member, &component) in &sccs.node_to_component {
        let supernode = component as i64;
        node_to_supernode.insert(member, supernode);
        supernode_members.entry(supernode).or_default().insert(member);
    }

    let mut unique_edges: AHashSet<(i64, i64)> = AHashSet::new();
    for &from_node in &graph.all_entity_ids()? {
        // Nodes unknown to the SCC mapping contribute no condensed edges.
        let src = match node_to_supernode.get(&from_node) {
            Some(&id) => id,
            None => continue,
        };
        for &to_node in &graph.fetch_outgoing(from_node)? {
            match node_to_supernode.get(&to_node) {
                // Intra-component edges would become self-loops; drop them.
                Some(&dst) if dst != src => {
                    unique_edges.insert((src, dst));
                }
                _ => {}
            }
        }
    }

    // The set already guarantees uniqueness; sorting makes the output independent of
    // hash iteration order.
    let mut supernode_edges: Vec<(i64, i64)> = unique_edges.into_iter().collect();
    supernode_edges.sort_unstable();

    Ok(SccCollapseResult {
        node_to_supernode,
        supernode_members,
        supernode_edges,
        num_sccs,
    })
}
/// Same condensation as [`collapse_sccs`], reporting progress through `progress`.
///
/// Progress is reported at these points:
/// - after the SCC pass;
/// - before the edge scan starts;
/// - after every 100th scanned edge;
/// - once more with the final tally, if the scan did not end exactly on a reporting
///   boundary.
///
/// `on_complete` is called on every successful return, including the empty-graph early
/// return. It is not called when an error is propagated.
///
/// The `total` passed during the edge scan is only an estimate (twice the number of
/// entities). The real edge count is unknown up front, so the processed count may
/// exceed it.
///
/// # Errors
///
/// Propagates any [`SqliteGraphError`] returned while computing the SCCs, listing entity
/// ids, or fetching outgoing neighbours.
pub fn collapse_sccs_with_progress<F>(
    graph: &SqliteGraph,
    progress: &F,
) -> Result<SccCollapseResult, SqliteGraphError>
where
    F: ProgressCallback,
{
    let scc_result = strongly_connected_components(graph)?;
    progress.on_progress(
        0,
        None,
        &format!(
            "SCC collapse: computed {} SCCs",
            scc_result.components.len()
        ),
    );
    if scc_result.components.is_empty() {
        progress.on_complete();
        return Ok(SccCollapseResult {
            node_to_supernode: AHashMap::new(),
            supernode_members: AHashMap::new(),
            supernode_edges: Vec::new(),
            num_sccs: 0,
        });
    }

    // Supernode ids are the SCC component indices.
    let mut node_to_supernode: AHashMap<i64, i64> = AHashMap::new();
    let mut supernode_members: AHashMap<i64, AHashSet<i64>> = AHashMap::new();
    for (&node, &component_idx) in &scc_result.node_to_component {
        let supernode_id = component_idx as i64;
        node_to_supernode.insert(node, supernode_id);
        supernode_members
            .entry(supernode_id)
            .or_default()
            .insert(node);
    }

    progress.on_progress(0, None, "SCC collapse: building condensed graph...");
    let mut edge_set: AHashSet<(i64, i64)> = AHashSet::new();
    let all_nodes = graph.all_entity_ids()?;
    // Heuristic denominator only; see the function docs.
    let total_edges_hint = all_nodes.len().saturating_mul(2);
    let mut edges_processed = 0;
    for &from_node in &all_nodes {
        let from_supernode = match node_to_supernode.get(&from_node) {
            Some(&id) => id,
            None => continue,
        };
        for &to_node in &graph.fetch_outgoing(from_node)? {
            edges_processed += 1;
            if let Some(&to_supernode) = node_to_supernode.get(&to_node) {
                if from_supernode != to_supernode {
                    edge_set.insert((from_supernode, to_supernode));
                }
            }
            // Checked per edge, right after incrementing. This reports every 100th edge
            // exactly once and never reports a zero count. (Checking per node reported
            // "0 edges" repeatedly and skipped thresholds crossed inside one node.)
            if edges_processed % 100 == 0 {
                progress.on_progress(
                    edges_processed,
                    Some(total_edges_hint),
                    &format!(
                        "SCC collapse: processed {} edges, {} unique supernode edges",
                        edges_processed,
                        edge_set.len()
                    ),
                );
            }
        }
    }
    // Final tally, unless the last in-loop report already covered it.
    if edges_processed % 100 != 0 {
        progress.on_progress(
            edges_processed,
            Some(total_edges_hint),
            &format!(
                "SCC collapse: processed {} edges, {} unique supernode edges",
                edges_processed,
                edge_set.len()
            ),
        );
    }

    // The set already guarantees uniqueness; sorting makes the output deterministic.
    let mut supernode_edges: Vec<(i64, i64)> = edge_set.into_iter().collect();
    supernode_edges.sort_unstable();
    progress.on_complete();
    Ok(SccCollapseResult {
        node_to_supernode,
        supernode_members,
        supernode_edges,
        num_sccs: scc_result.components.len(),
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{GraphEdge, GraphEntity};

    /// Opens a fresh, empty in-memory graph.
    fn create_empty_graph() -> SqliteGraph {
        SqliteGraph::open_in_memory().expect("Failed to create graph")
    }

    /// Inserts `count` entities named `node_0`, `node_1`, ... and returns the ids reported
    /// by `all_entity_ids`. Tests address nodes by their position in this vector.
    fn insert_nodes(graph: &SqliteGraph, count: usize) -> Vec<i64> {
        for i in 0..count {
            let entity = GraphEntity {
                id: 0,
                kind: "node".to_string(),
                name: format!("node_{}", i),
                file_path: Some(format!("node_{}.rs", i)),
                data: serde_json::json!({"index": i}),
            };
            graph
                .insert_entity(&entity)
                .expect("Failed to insert entity");
        }
        graph.all_entity_ids().expect("Failed to get IDs")
    }

    /// Inserts one `edge_type` edge per `(from_idx, to_idx)` pair; indices are positions
    /// in `ids`.
    fn insert_edges(graph: &SqliteGraph, ids: &[i64], pairs: &[(usize, usize)], edge_type: &str) {
        for &(from_idx, to_idx) in pairs {
            let edge = GraphEdge {
                id: 0,
                from_id: ids[from_idx],
                to_id: ids[to_idx],
                edge_type: edge_type.to_string(),
                data: serde_json::json!({}),
            };
            graph.insert_edge(&edge).expect("Failed to insert edge");
        }
    }

    /// A graph holding one isolated entity.
    fn create_single_node_graph() -> SqliteGraph {
        let graph = create_empty_graph();
        let entity = GraphEntity {
            id: 0,
            kind: "node".to_string(),
            name: "single".to_string(),
            file_path: Some("single.rs".to_string()),
            data: serde_json::json!({}),
        };
        graph
            .insert_entity(&entity)
            .expect("Failed to insert entity");
        graph
    }

    /// Four-node chain `0 -> 1 -> 2 -> 3`: a DAG, so every SCC is a single node.
    fn create_dag() -> SqliteGraph {
        let graph = create_empty_graph();
        let ids = insert_nodes(&graph, 4);
        let chain: Vec<(usize, usize)> = (0..ids.len().saturating_sub(1))
            .map(|i| (i, i + 1))
            .collect();
        insert_edges(&graph, &ids, &chain, "next");
        graph
    }

    /// Five nodes: `0 <-> 1` form a 2-cycle, plus the chain `2 -> 3 -> 4`.
    fn create_mutual_recursion_graph() -> SqliteGraph {
        let graph = create_empty_graph();
        let ids = insert_nodes(&graph, 5);
        insert_edges(&graph, &ids, &[(0, 1), (1, 0), (2, 3), (3, 4)], "edge");
        graph
    }

    /// Three nodes in one cycle `0 -> 1 -> 2 -> 0`.
    fn create_triangle_scc() -> SqliteGraph {
        let graph = create_empty_graph();
        let ids = insert_nodes(&graph, 3);
        insert_edges(&graph, &ids, &[(0, 1), (1, 2), (2, 0)], "cycle");
        graph
    }

    /// `count` nodes in one ring `0 -> 1 -> ... -> count-1 -> 0`.
    fn create_ring(count: usize) -> SqliteGraph {
        let graph = create_empty_graph();
        let ids = insert_nodes(&graph, count);
        let ring: Vec<(usize, usize)> = (0..count).map(|i| (i, (i + 1) % count)).collect();
        insert_edges(&graph, &ids, &ring, "ring");
        graph
    }

    /// Sorted ids of all nodes in non-trivial SCCs; independent of supernode numbering.
    fn sorted_cyclic_nodes(collapsed: &SccCollapseResult) -> Vec<i64> {
        let mut nodes: Vec<i64> = collapsed.non_trivial_nodes().iter().copied().collect();
        nodes.sort_unstable();
        nodes
    }

    #[test]
    fn test_collapse_sccs_empty_graph() {
        let graph = create_empty_graph();
        let collapsed = collapse_sccs(&graph).expect("collapse_sccs failed");
        assert_eq!(collapsed.num_sccs, 0);
        assert_eq!(collapsed.node_to_supernode.len(), 0);
        assert_eq!(collapsed.supernode_members.len(), 0);
        assert_eq!(collapsed.supernode_edges.len(), 0);
        assert_eq!(collapsed.non_trivial_count(), 0);
    }

    #[test]
    fn test_collapse_sccs_single_node() {
        let graph = create_single_node_graph();
        let collapsed = collapse_sccs(&graph).expect("collapse_sccs failed");
        assert_eq!(collapsed.num_sccs, 1);
        assert_eq!(collapsed.node_to_supernode.len(), 1);
        assert_eq!(collapsed.supernode_members.len(), 1);
        assert_eq!(collapsed.supernode_edges.len(), 0);
        assert_eq!(collapsed.non_trivial_count(), 0);
        let entity_ids: Vec<i64> = graph.all_entity_ids().expect("Failed to get IDs");
        let node_id = entity_ids[0];
        let supernode = collapsed
            .supernode_for(node_id)
            .expect("Node should have a supernode");
        let members = collapsed
            .members_of(supernode)
            .expect("Supernode should have members");
        assert_eq!(members.len(), 1, "SCC should have 1 member");
        assert!(members.contains(&node_id), "SCC should contain the node itself");
        assert!(
            collapsed.is_trivial(supernode),
            "Single node should be trivial"
        );
    }

    #[test]
    fn test_collapse_sccs_dag() {
        let graph = create_dag();
        let collapsed = collapse_sccs(&graph).expect("collapse_sccs failed");
        assert_eq!(collapsed.num_sccs, 4);
        assert_eq!(collapsed.node_to_supernode.len(), 4);
        assert_eq!(collapsed.non_trivial_count(), 0);
        assert_eq!(collapsed.supernode_edges.len(), 3);
        for &(from, to) in &collapsed.supernode_edges {
            assert_ne!(from, to, "Condensed graph should have no self-loops");
        }
        for (&supernode, members) in &collapsed.supernode_members {
            assert!(
                collapsed.is_trivial(supernode),
                "DAG nodes should be trivial SCCs"
            );
            assert_eq!(members.len(), 1, "Each SCC should have 1 member");
        }
        // Every chain edge must survive condensation, mapped onto its supernodes.
        let entity_ids: Vec<i64> = graph.all_entity_ids().expect("Failed to get IDs");
        for pair in entity_ids.windows(2) {
            let from = collapsed
                .supernode_for(pair[0])
                .expect("Chain node should have a supernode");
            let to = collapsed
                .supernode_for(pair[1])
                .expect("Chain node should have a supernode");
            assert!(
                collapsed.supernode_edges.contains(&(from, to)),
                "Missing condensed edge {} -> {}",
                from,
                to
            );
        }
    }

    #[test]
    fn test_collapse_sccs_mutual_recursion() {
        let graph = create_mutual_recursion_graph();
        let collapsed = collapse_sccs(&graph).expect("collapse_sccs failed");
        assert_eq!(collapsed.num_sccs, 4);
        assert_eq!(collapsed.node_to_supernode.len(), 5);
        assert_eq!(collapsed.non_trivial_count(), 1);
        let entity_ids: Vec<i64> = graph.all_entity_ids().expect("Failed to get IDs");
        let supernode_of = |idx: usize| {
            collapsed
                .supernode_for(entity_ids[idx])
                .unwrap_or_else(|| panic!("Node {} should have a supernode", idx))
        };
        let scc0 = supernode_of(0);
        let scc1 = supernode_of(1);
        let scc2 = supernode_of(2);
        let scc3 = supernode_of(3);
        let scc4 = supernode_of(4);
        assert_eq!(scc0, scc1, "Nodes 0 and 1 should be in same SCC");
        assert_ne!(scc2, scc0, "Node 2 should not be in SCC 0/1");
        assert_ne!(scc3, scc2, "Node 3 should not be in SCC 2");
        assert_ne!(scc4, scc3, "Node 4 should not be in SCC 3");
        assert_ne!(scc4, scc2, "Node 4 should not be in SCC 2");
        assert!(
            !collapsed.is_trivial(scc0),
            "Mutual recursion SCC should be non-trivial"
        );
        let members = collapsed
            .members_of(scc0)
            .expect("Mutual recursion SCC should have members");
        assert_eq!(
            members.len(),
            2,
            "Mutual recursion SCC should have 2 members"
        );
        // Only the chain 2 -> 3 -> 4 crosses SCC boundaries.
        assert_eq!(collapsed.supernode_edges.len(), 2);
        assert!(collapsed.supernode_edges.contains(&(scc2, scc3)));
        assert!(collapsed.supernode_edges.contains(&(scc3, scc4)));
        for &(from, to) in &collapsed.supernode_edges {
            assert_ne!(from, to, "Condensed graph should have no self-loops");
        }
    }

    #[test]
    fn test_collapse_sccs_triangle() {
        let graph = create_triangle_scc();
        let collapsed = collapse_sccs(&graph).expect("collapse_sccs failed");
        assert_eq!(collapsed.num_sccs, 1);
        assert_eq!(collapsed.node_to_supernode.len(), 3);
        assert_eq!(collapsed.non_trivial_count(), 1);
        let entity_ids: Vec<i64> = graph.all_entity_ids().expect("Failed to get IDs");
        let supernode_of = |idx: usize| {
            collapsed
                .supernode_for(entity_ids[idx])
                .unwrap_or_else(|| panic!("Node {} should have a supernode", idx))
        };
        let scc0 = supernode_of(0);
        let scc1 = supernode_of(1);
        let scc2 = supernode_of(2);
        assert_eq!(scc0, scc1, "Nodes 0 and 1 should be in same SCC");
        assert_eq!(scc1, scc2, "Nodes 1 and 2 should be in same SCC");
        assert_eq!(collapsed.supernode_edges.len(), 0);
        assert!(
            !collapsed.is_trivial(scc0),
            "Triangle SCC should be non-trivial"
        );
        let members = collapsed
            .members_of(scc0)
            .expect("Triangle SCC should have members");
        assert_eq!(members.len(), 3, "Triangle SCC should have 3 members");
    }

    #[test]
    fn test_collapse_sccs_no_self_loops() {
        let graph = create_mutual_recursion_graph();
        let collapsed = collapse_sccs(&graph).expect("collapse_sccs failed");
        for &(from, to) in &collapsed.supernode_edges {
            assert_ne!(from, to, "Condensed DAG should not have self-loops");
        }
    }

    #[test]
    fn test_collapse_sccs_bidirectional_mapping() {
        let graph = create_mutual_recursion_graph();
        let collapsed = collapse_sccs(&graph).expect("collapse_sccs failed");
        let entity_ids: Vec<i64> = graph.all_entity_ids().expect("Failed to get IDs");
        for &node in &entity_ids {
            let supernode = collapsed
                .supernode_for(node)
                .unwrap_or_else(|| panic!("Node {} should have a supernode", node));
            let members = collapsed
                .members_of(supernode)
                .unwrap_or_else(|| panic!("Supernode {} should have members", supernode));
            assert!(
                members.contains(&node),
                "Supernode {} should contain node {}",
                supernode,
                node
            );
        }
        for (&supernode, members) in &collapsed.supernode_members {
            for &member in members {
                let mapped = collapsed.supernode_for(member);
                assert_eq!(
                    Some(supernode),
                    mapped,
                    "Node {} should map back to supernode {}",
                    member,
                    supernode
                );
            }
        }
    }

    #[test]
    fn test_collapse_sccs_deterministic_edges() {
        for graph in [create_dag(), create_mutual_recursion_graph()] {
            let collapsed = collapse_sccs(&graph).expect("collapse_sccs failed");
            // Strictly increasing <=> sorted and free of duplicates.
            assert!(
                collapsed.supernode_edges.windows(2).all(|w| w[0] < w[1]),
                "Edges should be sorted and deduplicated: {:?}",
                collapsed.supernode_edges
            );
        }
    }

    #[test]
    fn test_supernode_for() {
        let graph = create_mutual_recursion_graph();
        let collapsed = collapse_sccs(&graph).expect("Failed");
        let entity_ids: Vec<i64> = graph.all_entity_ids().expect("Failed to get IDs");
        let supernode = collapsed.supernode_for(entity_ids[0]);
        assert!(supernode.is_some(), "Existing node should have supernode");
        let non_existent = collapsed.supernode_for(99999);
        assert!(
            non_existent.is_none(),
            "Non-existent node should return None"
        );
    }

    #[test]
    fn test_members_of() {
        let graph = create_mutual_recursion_graph();
        let collapsed = collapse_sccs(&graph).expect("Failed");
        let entity_ids: Vec<i64> = graph.all_entity_ids().expect("Failed to get IDs");
        let supernode = collapsed
            .supernode_for(entity_ids[0])
            .expect("Existing node should have supernode");
        let members = collapsed
            .members_of(supernode)
            .expect("Existing supernode should have members");
        assert_eq!(members.len(), 2, "Should have 2 members");
        let non_existent = collapsed.members_of(99999);
        assert!(
            non_existent.is_none(),
            "Non-existent supernode should return None"
        );
    }

    #[test]
    fn test_is_trivial() {
        let graph = create_mutual_recursion_graph();
        let collapsed = collapse_sccs(&graph).expect("Failed");
        let entity_ids: Vec<i64> = graph.all_entity_ids().expect("Failed to get IDs");
        let scc0 = collapsed
            .supernode_for(entity_ids[0])
            .expect("Node 0 should have a supernode");
        let scc2 = collapsed
            .supernode_for(entity_ids[2])
            .expect("Node 2 should have a supernode");
        assert!(
            !collapsed.is_trivial(scc0),
            "Multi-node SCC should not be trivial"
        );
        assert!(
            collapsed.is_trivial(scc2),
            "Single node SCC should be trivial"
        );
        assert!(
            !collapsed.is_trivial(99999),
            "Non-existent SCC should return false"
        );
    }

    #[test]
    fn test_non_trivial_count() {
        let graph = create_mutual_recursion_graph();
        let collapsed = collapse_sccs(&graph).expect("Failed");
        assert_eq!(collapsed.non_trivial_count(), 1);
        let dag = create_dag();
        let dag_collapsed = collapse_sccs(&dag).expect("Failed");
        assert_eq!(dag_collapsed.non_trivial_count(), 0);
        let triangle = create_triangle_scc();
        let triangle_collapsed = collapse_sccs(&triangle).expect("Failed");
        assert_eq!(triangle_collapsed.non_trivial_count(), 1);
    }

    #[test]
    fn test_non_trivial_nodes() {
        let graph = create_mutual_recursion_graph();
        let collapsed = collapse_sccs(&graph).expect("Failed");
        let cyclic_nodes = collapsed.non_trivial_nodes();
        assert_eq!(cyclic_nodes.len(), 2);
        let entity_ids: Vec<i64> = graph.all_entity_ids().expect("Failed to get IDs");
        assert!(cyclic_nodes.contains(&entity_ids[0]));
        assert!(cyclic_nodes.contains(&entity_ids[1]));
        assert!(!cyclic_nodes.contains(&entity_ids[2]));
        assert!(!cyclic_nodes.contains(&entity_ids[3]));
        assert!(!cyclic_nodes.contains(&entity_ids[4]));
    }

    #[test]
    fn test_collapse_sccs_with_progress() {
        use crate::progress::NoProgress;
        let graph = create_mutual_recursion_graph();
        let progress = NoProgress;
        let result_with = collapse_sccs_with_progress(&graph, &progress).expect("Failed");
        let result_without = collapse_sccs(&graph).expect("Failed");
        assert_eq!(
            result_with.num_sccs, result_without.num_sccs,
            "Progress and non-progress results should match"
        );
        assert_eq!(
            result_with.supernode_edges.len(),
            result_without.supernode_edges.len(),
            "Edges should match"
        );
        assert_eq!(
            result_with.non_trivial_count(),
            result_without.non_trivial_count(),
            "Non-trivial count should match"
        );
        assert_eq!(
            result_with.node_to_supernode.len(),
            result_without.node_to_supernode.len(),
            "Both variants should map the same number of nodes"
        );
        assert_eq!(
            sorted_cyclic_nodes(&result_with),
            sorted_cyclic_nodes(&result_without),
            "Cyclic nodes should match"
        );
    }

    #[test]
    fn test_collapse_sccs_with_progress_many_edges() {
        use crate::progress::NoProgress;
        // 150 edges: more than 100, so the periodic progress-report branch runs.
        let graph = create_ring(150);
        let progress = NoProgress;
        let collapsed = collapse_sccs_with_progress(&graph, &progress).expect("Failed");
        assert_eq!(collapsed.num_sccs, 1);
        assert_eq!(collapsed.non_trivial_count(), 1);
        assert_eq!(collapsed.non_trivial_nodes().len(), 150);
        assert!(collapsed.supernode_edges.is_empty());
    }

    #[test]
    fn test_collapse_sccs_empty_with_progress() {
        use crate::progress::NoProgress;
        let graph = create_empty_graph();
        let progress = NoProgress;
        let collapsed =
            collapse_sccs_with_progress(&graph, &progress).expect("collapse failed");
        assert_eq!(collapsed.num_sccs, 0);
        assert_eq!(collapsed.supernode_edges.len(), 0);
    }

    #[test]
    fn test_condensation_is_dag() {
        /// Depth-first search that returns `true` if a back edge (cycle) is reachable
        /// from `node`.
        fn has_cycle(
            node: i64,
            adj: &AHashMap<i64, Vec<i64>>,
            visited: &mut AHashSet<i64>,
            rec_stack: &mut AHashSet<i64>,
        ) -> bool {
            visited.insert(node);
            rec_stack.insert(node);
            if let Some(neighbors) = adj.get(&node) {
                for &neighbor in neighbors {
                    if !visited.contains(&neighbor) {
                        if has_cycle(neighbor, adj, visited, rec_stack) {
                            return true;
                        }
                    } else if rec_stack.contains(&neighbor) {
                        return true;
                    }
                }
            }
            rec_stack.remove(&node);
            false
        }

        let graph = create_mutual_recursion_graph();
        let collapsed = collapse_sccs(&graph).expect("Failed");
        let mut adj: AHashMap<i64, Vec<i64>> = AHashMap::new();
        for &(from, to) in &collapsed.supernode_edges {
            adj.entry(from).or_default().push(to);
        }
        let mut visited = AHashSet::new();
        let mut rec_stack = AHashSet::new();
        for &supernode in collapsed.supernode_members.keys() {
            if !visited.contains(&supernode) {
                assert!(
                    !has_cycle(supernode, &adj, &mut visited, &mut rec_stack),
                    "Condensation graph should be acyclic"
                );
            }
        }
    }
}