use std::collections::HashSet;
use ahash::{AHashMap, AHashSet};
use crate::{errors::SqliteGraphError, graph::SqliteGraph};
/// Outcome of a strongly-connected-components computation.
///
/// As populated by `strongly_connected_components` in this file, the three
/// fields describe one partition of the visited nodes: the components
/// themselves, a reverse lookup from node id to component index, and the
/// edges that connect different components.
#[derive(Debug, Clone)]
pub struct SccResult {
/// One entry per component, each holding the ids of the nodes it contains.
pub components: Vec<HashSet<i64>>,
/// Maps a node id to the index of its component within `components`.
pub node_to_component: AHashMap<i64, usize>,
/// Edges between distinct components as `(from_component, to_component)`
/// index pairs; the builder in this file emits them sorted and without
/// duplicates or self-loops.
pub condensed_edges: Vec<(usize, usize)>,
}
impl SccResult {
    /// Counts the components that contain more than one node.
    pub fn non_trivial_count(&self) -> usize {
        self.components
            .iter()
            .filter(|component| component.len() > 1)
            .count()
    }

    /// Collects the ids of every node that belongs to a multi-node component.
    pub fn non_trivial_nodes(&self) -> AHashSet<i64> {
        self.components
            .iter()
            .filter(|component| component.len() > 1)
            .flatten()
            .copied()
            .collect()
    }

    /// Reports whether `node` belongs to a component with more than one node.
    ///
    /// Returns `false` when `node` is not present in `node_to_component`, and
    /// also when the recorded component index does not exist in `components`.
    /// The fields are public, so the two can be edited out of sync; a boolean
    /// query should not panic on that.
    ///
    /// A single-node component always yields `false`, so a node whose only
    /// cycle is an edge back to itself is not reported by this method.
    pub fn is_in_cycle(&self, node: i64) -> bool {
        self.node_to_component
            .get(&node)
            .and_then(|&component_idx| self.components.get(component_idx))
            .map_or(false, |component| component.len() > 1)
    }
}
pub fn strongly_connected_components(graph: &SqliteGraph) -> Result<SccResult, SqliteGraphError> {
let all_ids = graph.all_entity_ids()?;
if all_ids.is_empty() {
return Ok(SccResult {
components: Vec::new(),
node_to_component: AHashMap::new(),
condensed_edges: Vec::new(),
});
}
let mut index_counter: i64 = 0;
let mut stack: Vec<i64> = Vec::new();
let mut on_stack: AHashSet<i64> = AHashSet::new();
let mut indices: AHashMap<i64, i64> = AHashMap::new();
let mut lowlink: AHashMap<i64, i64> = AHashMap::new();
let mut components: Vec<HashSet<i64>> = Vec::new();
let mut node_to_component: AHashMap<i64, usize> = AHashMap::new();
for &node in &all_ids {
if !indices.contains_key(&node) {
strongconnect(
graph,
node,
&mut index_counter,
&mut stack,
&mut on_stack,
&mut indices,
&mut lowlink,
&mut components,
&mut node_to_component,
)?;
}
}
let condensed_edges = build_condensed_dag(graph, &node_to_component, &components)?;
Ok(SccResult {
components,
node_to_component,
condensed_edges,
})
}
/// Runs one Tarjan depth-first search rooted at `v`, recording every
/// strongly connected component it completes.
///
/// The search is driven by an explicit frame stack rather than recursion, so
/// the depth of the traversal is limited by heap memory instead of the call
/// stack. A long chain or a large cycle therefore cannot overflow the stack.
/// Nodes are visited in the same order a recursive formulation would use, so
/// components are appended in the order their root nodes finish.
///
/// # Parameters
///
/// * `graph` - source of outgoing edges, read through `fetch_outgoing`.
/// * `v` - root of this search; the caller passes a node that is not yet in
///   `indices`.
/// * `index_counter` - next discovery index, incremented once per new node.
/// * `stack` / `on_stack` - nodes discovered but not yet assigned to a
///   component, as a stack and as a membership set.
/// * `indices` - discovery index of every node reached so far.
/// * `lowlink` - smallest discovery index known to be reachable from each
///   node through the search tree plus at most one edge into the open stack.
/// * `components` - receives each completed component.
/// * `node_to_component` - receives, for every node in a completed component,
///   that component's index in `components`.
///
/// # Errors
///
/// Returns the first error produced by `graph.fetch_outgoing`. The shared
/// state then still contains whatever was recorded before the failure.
fn strongconnect(
    graph: &SqliteGraph,
    v: i64,
    index_counter: &mut i64,
    stack: &mut Vec<i64>,
    on_stack: &mut AHashSet<i64>,
    indices: &mut AHashMap<i64, i64>,
    lowlink: &mut AHashMap<i64, i64>,
    components: &mut Vec<HashSet<i64>>,
    node_to_component: &mut AHashMap<i64, usize>,
) -> Result<(), SqliteGraphError> {
    /// Lowers `node`'s lowlink to `candidate` when that is smaller.
    fn lower(lowlink: &mut AHashMap<i64, i64>, node: i64, candidate: i64) {
        if let Some(current) = lowlink.get_mut(&node) {
            if candidate < *current {
                *current = candidate;
            }
        }
    }

    // One frame per node on the current search path: the node itself and an
    // iterator over its successors that have not been examined yet.
    let mut frames = Vec::new();
    // Node that has just been discovered and still needs a frame.
    let mut pending = Some(v);

    loop {
        if let Some(node) = pending.take() {
            indices.insert(node, *index_counter);
            lowlink.insert(node, *index_counter);
            *index_counter += 1;
            stack.push(node);
            on_stack.insert(node);
            frames.push((node, graph.fetch_outgoing(node)?.into_iter()));
        }

        let (node, next) = match frames.last_mut() {
            Some(frame) => (frame.0, frame.1.next()),
            None => break,
        };

        match next {
            Some(w) => match indices.get(&w).copied() {
                // Tree edge: descend into `w`. Its lowlink is folded into
                // `node` once `w`'s frame is finished.
                None => pending = Some(w),
                // Edge to a node still on the open stack, hence in the
                // component currently being assembled.
                Some(w_index) => {
                    if on_stack.contains(&w) {
                        lower(lowlink, node, w_index);
                    }
                }
            },
            None => {
                // Every successor of `node` has been handled.
                frames.pop();
                let node_low = *lowlink
                    .get(&node)
                    .expect("every framed node has a lowlink entry");
                let node_index = *indices
                    .get(&node)
                    .expect("every framed node has a discovery index");

                if node_low == node_index {
                    // `node` is the root of a component: everything above it
                    // on the open stack belongs to that component.
                    let component_idx = components.len();
                    let mut component = HashSet::new();
                    while let Some(member) = stack.pop() {
                        on_stack.remove(&member);
                        component.insert(member);
                        node_to_component.insert(member, component_idx);
                        if member == node {
                            break;
                        }
                    }
                    components.push(component);
                }

                // Propagate what `node` can reach to the node that discovered it.
                if let Some(parent) = frames.last() {
                    lower(lowlink, parent.0, node_low);
                }
            }
        }
    }

    Ok(())
}
/// Derives the edges of the condensed graph, in which every component is
/// collapsed to a single vertex.
///
/// For each entity listed by `graph.all_entity_ids()` that has a component
/// assignment, every outgoing edge whose target is also assigned and lies in
/// a different component contributes one `(from_component, to_component)`
/// pair. Edges inside a component, and edges touching unassigned nodes, are
/// ignored.
///
/// The pairs are returned without duplicates, in ascending order.
/// `_components` is accepted but not consulted.
///
/// # Errors
///
/// Propagates any `SqliteGraphError` raised by the graph queries.
fn build_condensed_dag(
    graph: &SqliteGraph,
    node_to_component: &AHashMap<i64, usize>,
    _components: &[HashSet<i64>],
) -> Result<Vec<(usize, usize)>, SqliteGraphError> {
    let mut links: AHashSet<(usize, usize)> = AHashSet::new();

    for &source in &graph.all_entity_ids()? {
        let source_comp = match node_to_component.get(&source) {
            Some(&comp) => comp,
            None => continue,
        };
        for &target in &graph.fetch_outgoing(source)? {
            match node_to_component.get(&target) {
                Some(&target_comp) if target_comp != source_comp => {
                    links.insert((source_comp, target_comp));
                }
                _ => {}
            }
        }
    }

    // The set already guarantees uniqueness; only the ordering is left to fix.
    let mut ordered: Vec<(usize, usize)> = links.into_iter().collect();
    ordered.sort_unstable();
    Ok(ordered)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::GraphEntity;

    /// Builds an in-memory graph holding ten entities and no edges.
    fn create_test_graph() -> SqliteGraph {
        let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
        for i in 0..10 {
            let entity = GraphEntity {
                id: 0,
                kind: "test".to_string(),
                name: format!("test_{}", i),
                file_path: Some(format!("test_{}.rs", i)),
                data: serde_json::json!({"index": i}),
            };
            graph
                .insert_entity(&entity)
                .expect("Failed to insert entity");
        }
        graph
    }

    /// Inserts one edge of type `edge_type` per `(from, to)` pair, where both
    /// values are positions in the list returned by `all_entity_ids`.
    ///
    /// Insertion failures abort the test immediately so that a broken fixture
    /// is not mistaken for a bug in the algorithm under test.
    fn insert_edges(graph: &SqliteGraph, pairs: &[(usize, usize)], edge_type: &str) {
        let entity_ids = graph.all_entity_ids().expect("Failed to get IDs");
        for &(from_idx, to_idx) in pairs {
            let edge = crate::GraphEdge {
                id: 0,
                from_id: entity_ids[from_idx],
                to_id: entity_ids[to_idx],
                edge_type: edge_type.to_string(),
                data: serde_json::json!({}),
            };
            graph.insert_edge(&edge).expect("Failed to insert edge");
        }
    }

    /// Ten entities linked as a single chain: 0 -> 1 -> ... -> 9.
    fn create_linear_chain_graph() -> SqliteGraph {
        let graph = create_test_graph();
        let node_count = graph.all_entity_ids().expect("Failed to get IDs").len();
        let chain: Vec<(usize, usize)> = (1..node_count).map(|next| (next - 1, next)).collect();
        insert_edges(&graph, &chain, "next");
        graph
    }

    /// Ten entities where the first three form the cycle 0 -> 1 -> 2 -> 0.
    fn create_simple_cycle_graph() -> SqliteGraph {
        let graph = create_test_graph();
        insert_edges(&graph, &[(0, 1), (1, 2), (2, 0)], "cycle");
        graph
    }

    /// Ten entities with one two-node cycle (0 <-> 1) and a short chain
    /// 2 -> 3 -> 4.
    fn create_mutual_recursion_graph() -> SqliteGraph {
        let graph = create_test_graph();
        insert_edges(&graph, &[(0, 1), (1, 0), (2, 3), (3, 4)], "edge");
        graph
    }

    #[test]
    fn test_scc_empty_graph() {
        let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
        let scc = strongly_connected_components(&graph).expect("SCC computation should succeed");
        assert_eq!(scc.components.len(), 0);
        assert_eq!(scc.node_to_component.len(), 0);
        assert_eq!(scc.condensed_edges.len(), 0);
    }

    #[test]
    fn test_scc_single_node() {
        let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
        let entity = GraphEntity {
            id: 0,
            kind: "test".to_string(),
            name: "single".to_string(),
            file_path: Some("single.rs".to_string()),
            data: serde_json::json!({}),
        };
        graph
            .insert_entity(&entity)
            .expect("Failed to insert entity");

        let scc = strongly_connected_components(&graph).expect("SCC computation should succeed");
        assert_eq!(scc.components.len(), 1);
        assert_eq!(scc.node_to_component.len(), 1);
        assert_eq!(scc.components[0].len(), 1);
        assert_eq!(scc.non_trivial_count(), 0);
    }

    #[test]
    fn test_scc_linear_chain() {
        let graph = create_linear_chain_graph();
        let scc = strongly_connected_components(&graph).expect("SCC computation should succeed");
        // A chain has no cycles: every node is its own component and every
        // chain edge survives condensation.
        assert_eq!(scc.components.len(), 10);
        assert_eq!(scc.node_to_component.len(), 10);
        assert_eq!(scc.non_trivial_count(), 0);
        assert_eq!(scc.condensed_edges.len(), 9);
    }

    #[test]
    fn test_scc_simple_cycle() {
        let graph = create_simple_cycle_graph();
        let scc = strongly_connected_components(&graph).expect("SCC computation should succeed");
        // One three-node component plus seven isolated nodes.
        assert_eq!(scc.components.len(), 8);
        assert_eq!(scc.node_to_component.len(), 10);
        assert_eq!(scc.non_trivial_count(), 1);

        let cycle_component = scc
            .components
            .iter()
            .find(|c| c.len() == 3)
            .expect("Should have a 3-node SCC");
        let entity_ids = graph.all_entity_ids().expect("Failed to get IDs");
        assert!(cycle_component.contains(&entity_ids[0]));
        assert!(cycle_component.contains(&entity_ids[1]));
        assert!(cycle_component.contains(&entity_ids[2]));
        for node in cycle_component {
            assert!(scc.is_in_cycle(*node));
        }
    }

    #[test]
    fn test_scc_mutual_recursion() {
        let graph = create_mutual_recursion_graph();
        let scc = strongly_connected_components(&graph).expect("SCC computation should succeed");
        // One two-node component plus eight single-node components.
        assert_eq!(scc.components.len(), 9);
        assert_eq!(scc.non_trivial_count(), 1);

        let recursion_component = scc
            .components
            .iter()
            .find(|c| c.len() == 2)
            .expect("Should have a 2-node SCC");
        assert_eq!(recursion_component.len(), 2);
    }

    #[test]
    fn test_scc_condensed_dag() {
        let graph = create_mutual_recursion_graph();
        let scc = strongly_connected_components(&graph).expect("SCC computation should succeed");

        // Only 2 -> 3 and 3 -> 4 cross component boundaries; 0 <-> 1 stays
        // inside one component.
        assert_eq!(scc.condensed_edges.len(), 2);
        for &(from, to) in &scc.condensed_edges {
            assert_ne!(from, to, "Condensed DAG should not have self-loops");
            assert!(from < scc.components.len());
            assert!(to < scc.components.len());
        }
    }
}