use ahash::AHashMap;
use std::collections::HashSet;
use crate::progress::ProgressCallback;
use crate::{errors::SqliteGraphError, graph::SqliteGraph};
/// Computes the transitive reduction of `graph` as a set of `(from, to)` edges.
///
/// The transitive closure of the graph is computed first. Every outgoing edge of
/// every entity is then kept unless `is_reachable_via_intermediate` reports that
/// the target is also reachable through another node. Because the result is a
/// `HashSet`, parallel edges between the same pair of nodes collapse into one entry.
///
/// # Errors
///
/// Propagates any [`SqliteGraphError`] from the closure computation, from listing
/// entity ids, or from fetching outgoing edges.
pub fn transitive_reduction(graph: &SqliteGraph) -> Result<HashSet<(i64, i64)>, SqliteGraphError> {
    let closure = super::transitive_closure::transitive_closure(graph, None)?;
    let mut kept = HashSet::new();
    let sources = graph.all_entity_ids()?;
    for &source in &sources {
        for &target in &graph.fetch_outgoing(source)? {
            // Only edges with no alternative path survive the reduction.
            if !is_reachable_via_intermediate(&closure, source, target) {
                kept.insert((source, target));
            }
        }
    }
    Ok(kept)
}
/// Computes the transitive reduction of `graph` while reporting progress.
///
/// This runs the same per-edge check as `transitive_reduction`. The differences:
/// - the closure is built with `transitive_closure_with_progress`, so `progress`
///   also observes that phase;
/// - before scanning each source node, `progress.on_progress` receives the
///   1-based position, `Some(total node count)` and a short message;
/// - `progress.on_complete` is called once after every node has been scanned.
///
/// # Errors
///
/// Propagates any [`SqliteGraphError`] from the closure computation, from listing
/// entity ids, or from fetching outgoing edges. On an early error return,
/// `on_complete` is not called.
pub fn transitive_reduction_with_progress<F>(
    graph: &SqliteGraph,
    progress: &F,
) -> Result<HashSet<(i64, i64)>, SqliteGraphError>
where
    F: ProgressCallback,
{
    let closure =
        super::transitive_closure::transitive_closure_with_progress(graph, None, progress)?;
    let mut kept = HashSet::new();
    let sources = graph.all_entity_ids()?;
    let total = sources.len();
    let mut step = 0;
    for &source in &sources {
        step += 1;
        let message = format!("Transitive reduction: source {}/{}", step, total);
        progress.on_progress(step, Some(total), &message);
        for &target in &graph.fetch_outgoing(source)? {
            // Only edges with no alternative path survive the reduction.
            if !is_reachable_via_intermediate(&closure, source, target) {
                kept.insert((source, target));
            }
        }
    }
    progress.on_complete();
    Ok(kept)
}
/// Returns `true` when the edge `from_id -> to_id` is implied by a longer path.
///
/// An edge is redundant when both of these hold:
/// - `closure` records `to_id` as reachable from `from_id`;
/// - there is an intermediate node `mid`, distinct from both endpoints, such that
///   `from_id` reaches `mid` and `mid` reaches `to_id`.
///
/// A pair counts as reachable only if it is present in `closure` with the value
/// `true`.
///
/// This criterion is exact for acyclic graphs. On graphs with cycles of length
/// three or more, every edge of such a cycle has an intermediate witness, so all
/// of them are reported as redundant. Reachability is therefore not preserved
/// in that case.
///
/// Each call scans all entries of `closure`.
fn is_reachable_via_intermediate(
    closure: &AHashMap<(i64, i64), bool>,
    from_id: i64,
    to_id: i64,
) -> bool {
    let reaches = |src: i64, dst: i64| closure.get(&(src, dst)).copied().unwrap_or(false);

    if !reaches(from_id, to_id) {
        return false;
    }
    // Exclude both endpoints as intermediates:
    // - `mid == to_id` is the edge itself;
    // - `mid == from_id` would match any reflexive/cyclic `(from, from)` entry
    //   and wrongly mark every outgoing edge of `from_id` as redundant.
    // `false` entries in the closure are not reachability witnesses.
    closure.iter().any(|(&(src, mid), &reachable)| {
        reachable && src == from_id && mid != from_id && mid != to_id && reaches(mid, to_id)
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{GraphEdge, GraphEntity};

    /// Builds an in-memory graph with `node_count` "test" entities and the given
    /// edges, expressed as indices into the entity-id list.
    ///
    /// Returns the graph together with the ids from `list_entity_ids`, so tests
    /// can translate indices into ids.
    fn build_graph(node_count: usize, edges: &[(usize, usize)]) -> (SqliteGraph, Vec<i64>) {
        let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
        for i in 0..node_count {
            let entity = GraphEntity {
                id: 0,
                kind: "test".to_string(),
                name: format!("node_{}", i),
                file_path: Some(format!("test_{}.rs", i)),
                data: serde_json::json!({"index": i}),
            };
            graph
                .insert_entity(&entity)
                .expect("Failed to insert entity");
        }
        let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
        for &(from_idx, to_idx) in edges {
            let edge = GraphEdge {
                id: 0,
                from_id: entity_ids[from_idx],
                to_id: entity_ids[to_idx],
                edge_type: "connects".to_string(),
                data: serde_json::json!({}),
            };
            graph.insert_edge(&edge).expect("Failed to insert edge");
        }
        (graph, entity_ids)
    }

    /// Chain 0 -> 1 -> 2 -> 3; every edge is essential.
    fn create_linear_graph() -> (SqliteGraph, Vec<i64>) {
        build_graph(4, &[(0, 1), (1, 2), (2, 3)])
    }

    /// 0 -> {1, 2} -> 3 plus the redundant shortcut 0 -> 3.
    fn create_diamond_graph() -> (SqliteGraph, Vec<i64>) {
        build_graph(4, &[(0, 1), (0, 2), (1, 3), (2, 3), (0, 3)])
    }

    /// Every i -> j with i < j on four nodes; reduces to the chain 0 -> 1 -> 2 -> 3.
    fn create_fully_connected_graph() -> (SqliteGraph, Vec<i64>) {
        let edges: Vec<(usize, usize)> = (0..4)
            .flat_map(|i| ((i + 1)..4).map(move |j| (i, j)))
            .collect();
        build_graph(4, &edges)
    }

    /// Translates index pairs into the id pairs used by the reduction output.
    fn id_pairs(ids: &[i64], pairs: &[(usize, usize)]) -> HashSet<(i64, i64)> {
        pairs.iter().map(|&(a, b)| (ids[a], ids[b])).collect()
    }

    /// Returns every node reachable from `start` by following `edges`.
    /// `start` is included only if it lies on a cycle.
    fn reachable_from(edges: &HashSet<(i64, i64)>, start: i64) -> HashSet<i64> {
        let mut seen = HashSet::new();
        let mut stack = vec![start];
        while let Some(node) = stack.pop() {
            for &(src, dst) in edges {
                if src == node && seen.insert(dst) {
                    stack.push(dst);
                }
            }
        }
        seen
    }

    #[test]
    fn test_transitive_reduction_empty() {
        let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
        let essential = transitive_reduction(&graph).expect("transitive_reduction failed");
        assert!(essential.is_empty(), "Expected empty set for empty graph");
    }

    #[test]
    fn test_transitive_reduction_single_node() {
        let (graph, _) = build_graph(1, &[]);
        let essential = transitive_reduction(&graph).expect("transitive_reduction failed");
        assert!(essential.is_empty(), "Expected empty set for single node");
    }

    #[test]
    fn test_transitive_reduction_linear_chain() {
        let (graph, ids) = create_linear_graph();
        let essential = transitive_reduction(&graph).expect("transitive_reduction failed");
        assert_eq!(
            essential,
            id_pairs(&ids, &[(0, 1), (1, 2), (2, 3)]),
            "A chain has no redundant edges"
        );
    }

    #[test]
    fn test_transitive_reduction_diamond() {
        let (graph, ids) = create_diamond_graph();
        let essential = transitive_reduction(&graph).expect("transitive_reduction failed");
        assert!(
            !essential.contains(&(ids[0], ids[3])),
            "Shortcut 0->3 should have been removed"
        );
        assert_eq!(
            essential,
            id_pairs(&ids, &[(0, 1), (0, 2), (1, 3), (2, 3)]),
            "Diamond should keep exactly its four side edges"
        );
    }

    #[test]
    fn test_transitive_reduction_fully_connected() {
        let (graph, ids) = create_fully_connected_graph();
        let essential = transitive_reduction(&graph).expect("transitive_reduction failed");
        assert_eq!(
            essential,
            id_pairs(&ids, &[(0, 1), (1, 2), (2, 3)]),
            "Transitive tournament should reduce to the chain"
        );
    }

    #[test]
    fn test_transitive_reduction_preserves_reachability() {
        let (graph, _) = create_diamond_graph();
        let original_closure = super::super::transitive_closure::transitive_closure(&graph, None)
            .expect("Failed to compute original closure");
        let essential = transitive_reduction(&graph).expect("Failed to compute reduction");
        assert!(
            !original_closure.is_empty(),
            "Diamond closure should not be empty"
        );

        // The reduction may only drop edges, never invent them.
        for &(from_id, to_id) in &essential {
            let outgoing = graph
                .fetch_outgoing(from_id)
                .expect("Failed to get outgoing");
            assert!(
                outgoing.contains(&to_id),
                "Essential edge {}->{} is not in the original graph",
                from_id,
                to_id
            );
        }

        // Every non-trivial reachable pair must stay reachable via essential edges.
        for (&(src, dst), &reachable) in original_closure.iter() {
            if !reachable || src == dst {
                continue;
            }
            assert!(
                reachable_from(&essential, src).contains(&dst),
                "{} should still reach {} after reduction",
                src,
                dst
            );
        }
    }

    #[test]
    fn test_transitive_reduction_with_progress() {
        use crate::progress::NoProgress;
        let (graph, ids) = create_diamond_graph();
        let progress = NoProgress;
        let essential = transitive_reduction_with_progress(&graph, &progress)
            .expect("transitive_reduction_with_progress failed");
        assert_eq!(
            essential,
            id_pairs(&ids, &[(0, 1), (0, 2), (1, 3), (2, 3)]),
            "Progress variant should match the plain reduction"
        );
    }

    #[test]
    fn test_transitive_reduction_deterministic() {
        let (graph, _) = create_diamond_graph();
        let essential1 = transitive_reduction(&graph).expect("First transitive_reduction failed");
        let essential2 = transitive_reduction(&graph).expect("Second transitive_reduction failed");
        assert_eq!(essential1, essential2, "Essential edges differ");
    }
}