use std::collections::{HashSet, VecDeque};
use ahash::AHashMap;
use crate::graph::SqliteGraph;
/// Error returned by [`topological_sort`].
///
/// Despite the variant's name, it is also used to report failures while
/// reading the graph; in that case `cycle` is empty and `explanation` carries
/// the underlying error message.
#[derive(Debug, Clone)]
pub enum TopoError {
    /// The graph is not a DAG (or could not be read, see the type docs).
    CycleDetected {
        /// Node ids along the offending cycle; empty when no path is available.
        cycle: Vec<i64>,
        /// Human-readable description of what went wrong.
        explanation: String,
    },
}

impl std::fmt::Display for TopoError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The enum has a single variant, so this destructuring is irrefutable.
        let TopoError::CycleDetected { cycle, explanation } = self;
        write!(f, "{}: Cycle detected: {:?}", explanation, cycle)
    }
}

impl std::error::Error for TopoError {}
/// Return the nodes of `graph` in a topological order (for every edge
/// `u -> v`, `u` appears before `v`), using Kahn's algorithm.
///
/// The initial zero-in-degree nodes are enqueued in the order returned by
/// `all_entity_ids`, so the result does not depend on hash-map iteration
/// order (assuming the database returns rows in a stable order).
///
/// # Errors
///
/// Returns [`TopoError::CycleDetected`] when:
/// - the graph contains a cycle through two or more nodes (`cycle` holds one
///   such cycle and `explanation` the number of cyclic components);
/// - a node has an edge to itself (`cycle` is `[node, node]`);
/// - reading nodes, edges, or SCCs from the database fails (`cycle` is empty
///   and `explanation` carries the underlying error).
pub fn topological_sort(graph: &SqliteGraph) -> Result<Vec<i64>, TopoError> {
    let all_nodes = graph
        .all_entity_ids()
        .map_err(|e| TopoError::CycleDetected {
            cycle: vec![],
            explanation: format!("Failed to get nodes: {}", e),
        })?;
    if all_nodes.is_empty() {
        return Ok(Vec::new());
    }

    // Cycles through two or more nodes show up as SCCs of size > 1.
    let scc = crate::algo::scc::strongly_connected_components(graph).map_err(|e| {
        TopoError::CycleDetected {
            cycle: vec![],
            explanation: format!("Failed to compute SCC: {}", e),
        }
    })?;
    let non_trivial_sccs: Vec<_> = scc.components.into_iter().filter(|c| c.len() > 1).collect();
    if let Some(first) = non_trivial_sccs.first() {
        return Err(TopoError::CycleDetected {
            cycle: extract_cycle_path(graph, first),
            explanation: format!(
                "Found {} cycle(s) - graph is not a DAG",
                non_trivial_sccs.len()
            ),
        });
    }

    // Fetch each node's successors exactly once; Kahn's pass below reuses
    // them instead of querying the database a second time.
    let mut successors: AHashMap<i64, Vec<i64>> = AHashMap::new();
    for &node in &all_nodes {
        let targets: Vec<i64> = graph
            .fetch_outgoing(node)
            .map_err(|e| TopoError::CycleDetected {
                cycle: vec![],
                explanation: format!("Failed to get outgoing edges: {}", e),
            })?
            .into_iter()
            .collect();
        // A self-loop is an SCC of size 1, so the component check above
        // cannot see it; report it directly instead of via the generic
        // fallback error below.
        if targets.contains(&node) {
            return Err(TopoError::CycleDetected {
                cycle: vec![node, node],
                explanation: format!(
                    "Found self-loop cycle on node {} - graph is not a DAG",
                    node
                ),
            });
        }
        successors.insert(node, targets);
    }

    let mut in_degree: AHashMap<i64, usize> = AHashMap::new();
    for &node in &all_nodes {
        in_degree.insert(node, 0);
    }
    for targets in successors.values() {
        for &target in targets {
            // `entry` also counts targets that are missing from `all_nodes`.
            *in_degree.entry(target).or_insert(0) += 1;
        }
    }

    // Seed with source nodes in `all_nodes` order for a deterministic result.
    let mut queue: VecDeque<i64> = all_nodes
        .iter()
        .copied()
        .filter(|node| in_degree.get(node) == Some(&0))
        .collect();

    let mut result = Vec::with_capacity(all_nodes.len());
    while let Some(node) = queue.pop_front() {
        result.push(node);
        if let Some(targets) = successors.get(&node) {
            for &target in targets {
                if let Some(deg) = in_degree.get_mut(&target) {
                    *deg -= 1;
                    if *deg == 0 {
                        queue.push_back(target);
                    }
                }
            }
        }
    }

    // Defensive: the checks above should rule out cycles, but if the edge set
    // differs between reads (or edges point at unknown nodes) Kahn's pass may
    // not cover exactly the known nodes.
    if result.len() != all_nodes.len() {
        return Err(TopoError::CycleDetected {
            cycle: vec![],
            explanation: "Graph contains cycle".to_string(),
        });
    }
    Ok(result)
}
/// Find one concrete cycle inside the strongly connected component `scc`.
///
/// The walk starts from the smallest node id in `scc`, so the report does not
/// depend on `HashSet` iteration order. It repeatedly follows the first
/// outgoing edge that stays inside the component until it reaches a node
/// already on the path.
///
/// The returned cycle is always *closed*: its first and last elements are the
/// same node, and each consecutive pair is an edge of `graph` (e.g.
/// `[a, b, c, a]`).
///
/// This is a best-effort diagnostic. It returns an empty vector for an empty
/// component. In two other cases it returns the open path walked so far
/// instead of a cycle:
/// - an edge lookup fails;
/// - a node has no successor inside `scc`, which should not happen for a
///   genuine SCC of size > 1.
fn extract_cycle_path(graph: &SqliteGraph, scc: &HashSet<i64>) -> Vec<i64> {
    let start = match scc.iter().min() {
        Some(&node) => node,
        None => return Vec::new(),
    };
    let mut path = vec![start];
    // Mirrors the contents of `path` so revisits are detected in O(1).
    let mut on_path: HashSet<i64> = HashSet::new();
    on_path.insert(start);
    let mut current = start;
    loop {
        let outgoing = match graph.fetch_outgoing(current) {
            Ok(outgoing) => outgoing,
            // Best effort: keep the caller's cycle error rather than failing here.
            Err(_) => return path,
        };
        let next = match outgoing.into_iter().find(|n| scc.contains(n)) {
            Some(next) => next,
            None => return path,
        };
        if !on_path.insert(next) {
            // `next` is already on the path: path[idx..] followed by the edge
            // `current -> next` closes a cycle.
            let idx = path
                .iter()
                .position(|&n| n == next)
                .expect("on_path mirrors the contents of path");
            let mut cycle = path[idx..].to_vec();
            cycle.push(next);
            return cycle;
        }
        path.push(next);
        current = next;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{GraphEdge, GraphEntity};

    /// Create an in-memory graph holding `count` entities named
    /// `{name_prefix}_{i}` with file path `{file_prefix}_{i}.rs`.
    ///
    /// Returns the graph together with its entity ids as reported by
    /// `list_entity_ids`.
    fn graph_with_entities(
        name_prefix: &str,
        file_prefix: &str,
        count: usize,
    ) -> (SqliteGraph, Vec<i64>) {
        let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
        for i in 0..count {
            let entity = GraphEntity {
                id: 0,
                kind: "test".to_string(),
                name: format!("{}_{}", name_prefix, i),
                file_path: Some(format!("{}_{}.rs", file_prefix, i)),
                data: serde_json::json!({"index": i}),
            };
            graph
                .insert_entity(&entity)
                .expect("Failed to insert entity");
        }
        let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
        (graph, entity_ids)
    }

    /// Insert one `from_id -> to_id` edge of the given type.
    fn add_edge(graph: &SqliteGraph, from_id: i64, to_id: i64, edge_type: &str) {
        let edge = GraphEdge {
            id: 0,
            from_id,
            to_id,
            edge_type: edge_type.to_string(),
            data: serde_json::json!({}),
        };
        graph.insert_edge(&edge).expect("Failed to insert edge");
    }

    /// Index of `node` in `ordering`; panics with context if it is missing,
    /// instead of hiding the absence behind a sentinel position.
    fn position_of(ordering: &[i64], node: i64) -> usize {
        ordering
            .iter()
            .position(|&n| n == node)
            .unwrap_or_else(|| panic!("node {} missing from ordering {:?}", node, ordering))
    }

    /// Four nodes chained `0 -> 1 -> 2 -> 3`.
    fn create_linear_chain_graph() -> SqliteGraph {
        let (graph, ids) = graph_with_entities("node", "test", 4);
        for pair in ids.windows(2) {
            add_edge(&graph, pair[0], pair[1], "next");
        }
        graph
    }

    /// Diamond DAG: `0 -> 1 -> 3` and `0 -> 2 -> 3`.
    fn create_diamond_dag() -> SqliteGraph {
        let (graph, ids) = graph_with_entities("node", "test", 4);
        for (from_idx, to_idx) in [(0, 1), (1, 3), (0, 2), (2, 3)] {
            add_edge(&graph, ids[from_idx], ids[to_idx], "edge");
        }
        graph
    }

    /// Three nodes in a ring: `0 -> 1 -> 2 -> 0`.
    fn create_cycle_graph() -> SqliteGraph {
        let (graph, ids) = graph_with_entities("cycle", "cycle", 3);
        for (&from, &to) in ids.iter().zip(ids.iter().cycle().skip(1)) {
            add_edge(&graph, from, to, "cycle");
        }
        graph
    }

    #[test]
    fn test_topo_sort_empty() {
        let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
        let ordering =
            topological_sort(&graph).expect("Topological sort failed on empty graph");
        assert!(ordering.is_empty(), "Expected empty ordering for empty graph");
    }

    #[test]
    fn test_topo_sort_single_node() {
        let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
        let entity = GraphEntity {
            id: 0,
            kind: "test".to_string(),
            name: "single".to_string(),
            file_path: Some("single.rs".to_string()),
            data: serde_json::json!({}),
        };
        graph
            .insert_entity(&entity)
            .expect("Failed to insert entity");
        let ordering =
            topological_sort(&graph).expect("Topological sort failed on single node");
        assert_eq!(ordering.len(), 1, "Expected single node in ordering");
        let entity_ids = graph.list_entity_ids().expect("Failed to get IDs");
        assert_eq!(ordering[0], entity_ids[0], "Ordering should contain the node");
    }

    #[test]
    fn test_topo_sort_linear_chain() {
        let graph = create_linear_chain_graph();
        let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
        let ordering =
            topological_sort(&graph).expect("Topological sort failed on linear chain");
        assert_eq!(ordering.len(), 4, "Expected 4 nodes in ordering");
        for pair in entity_ids.windows(2) {
            let (from, to) = (pair[0], pair[1]);
            let from_pos = position_of(&ordering, from);
            let to_pos = position_of(&ordering, to);
            assert!(
                from_pos < to_pos,
                "Edge {} -> {} violates topological order ({} at {}, {} at {})",
                from,
                to,
                from,
                from_pos,
                to,
                to_pos
            );
        }
    }

    #[test]
    fn test_topo_sort_diamond() {
        let graph = create_diamond_dag();
        let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
        let ordering =
            topological_sort(&graph).expect("Topological sort failed on diamond DAG");
        assert_eq!(ordering.len(), 4, "Expected 4 nodes in ordering");
        let positions: Vec<usize> = entity_ids[..4]
            .iter()
            .map(|&id| position_of(&ordering, id))
            .collect();
        for (before, after) in [(0, 1), (0, 2), (0, 3), (1, 3), (2, 3)] {
            assert!(
                positions[before] < positions[after],
                "{} should come before {}",
                before,
                after
            );
        }
    }

    #[test]
    fn test_topo_sort_cycle() {
        let graph = create_cycle_graph();
        let err = topological_sort(&graph)
            .expect_err("Topological sort should fail on cyclic graph");
        match err {
            TopoError::CycleDetected { cycle, explanation } => {
                assert!(!cycle.is_empty(), "Cycle should not be empty");
                assert!(
                    explanation.contains("cycle"),
                    "Explanation should mention cycles"
                );
                assert!(cycle.len() >= 3, "Cycle should have at least 3 nodes");
            }
        }
    }

    #[test]
    fn test_topo_sort_cycle_has_path() {
        let graph = create_cycle_graph();
        let err = topological_sort(&graph).expect_err("Should detect cycle");
        match err {
            TopoError::CycleDetected { cycle, .. } => {
                assert!(cycle.len() >= 3, "Cycle should have at least 3 nodes");
                let valid_nodes: std::collections::HashSet<i64> = graph
                    .list_entity_ids()
                    .expect("Failed to get IDs")
                    .into_iter()
                    .collect();
                for node in &cycle {
                    assert!(
                        valid_nodes.contains(node),
                        "Cycle node {} should be in graph",
                        node
                    );
                }
                // Every consecutive pair in the reported cycle must be a real edge.
                for pair in cycle.windows(2) {
                    let has_edge = graph
                        .fetch_outgoing(pair[0])
                        .expect("Failed to get outgoing edges")
                        .into_iter()
                        .any(|target| target == pair[1]);
                    assert!(
                        has_edge,
                        "Cycle step {} -> {} is not an edge of the graph",
                        pair[0],
                        pair[1]
                    );
                }
            }
        }
    }
}