use std::fmt;
use ahash::{AHashMap, AHashSet};
use serde_json::Value;
use crate::algo::topological_sort::{TopoError, topological_sort};
use crate::graph::SqliteGraph;
use crate::progress::ProgressCallback;
/// Errors produced by critical path analysis.
#[derive(Debug, Clone)]
pub enum CriticalPathError {
    /// The graph contains a cycle, so no topological order (and hence no
    /// critical path) exists.
    NotADag {
        /// Nodes participating in the detected cycle.
        cycle: Vec<i64>,
    },
    /// A weight could not be obtained for an edge: the weight callback
    /// returned a non-finite value, or fetching outgoing edges failed.
    InvalidWeight {
        /// Source node of the offending edge.
        from: i64,
        /// Target node of the offending edge; 0 is used as a sentinel when
        /// the target is unknown (e.g. the edge fetch itself failed).
        to: i64,
        /// Human-readable explanation of the failure.
        reason: String,
    },
}
impl fmt::Display for CriticalPathError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
CriticalPathError::NotADag { cycle } => {
write!(
f,
"Critical path analysis requires a DAG: cycle detected: {:?}",
cycle
)
}
CriticalPathError::InvalidWeight { from, to, reason } => {
write!(f, "Invalid weight for edge {} -> {}: {}", from, to, reason)
}
}
}
}
// Marker impl so CriticalPathError composes with `?`, `Box<dyn Error>`, etc.
impl std::error::Error for CriticalPathError {}
/// Result of a critical path computation over a DAG.
#[derive(Debug, Clone)]
pub struct CriticalPathResult {
    /// Nodes on the critical (longest) path, in order from start to end.
    pub path: Vec<i64>,
    /// Total accumulated weight of the critical path.
    pub distance: f64,
    /// Longest accumulated distance reaching each node.
    pub distances: AHashMap<i64, f64>,
    /// Best predecessor of each node on its longest incoming path
    /// (`None` for nodes never relaxed, i.e. sources).
    pub predecessors: AHashMap<i64, Option<i64>>,
    /// Topological order the computation iterated over.
    pub topological_order: Vec<i64>,
}
impl CriticalPathResult {
pub fn bottlenecks(&self) -> AHashSet<i64> {
self.path.iter().copied().collect()
}
pub fn slack(&self) -> AHashMap<i64, f64> {
self.distances
.iter()
.map(|(&node, &dist)| (node, self.distance - dist))
.collect()
}
pub fn is_bottleneck(&self, node: i64) -> bool {
self.path.contains(&node)
}
}
/// Callback computing the weight of the edge `from -> to` given its JSON
/// payload.
pub type WeightCallback = dyn Fn(i64, i64, &Value) -> f64;

/// Default weight function: every edge costs 1.0 regardless of its data.
pub fn default_weight_fn(_from: i64, _to: i64, _edge_data: &Value) -> f64 {
    1.0
}
/// Computes the longest (critical) path through a DAG.
///
/// Performs one relaxation pass over the nodes in topological order,
/// tracking the maximum accumulated weight reaching each node, then
/// reconstructs the heaviest path by walking predecessor links backwards
/// from the node with the greatest distance.
///
/// NOTE(review): edge payloads are not loaded here — `weight_fn` always
/// receives an empty JSON object, so weights based on stored edge data
/// resolve to the callback's fallback value. Confirm whether real edge
/// data should be fetched.
///
/// # Errors
///
/// * [`CriticalPathError::NotADag`] if the graph contains a cycle.
/// * [`CriticalPathError::InvalidWeight`] if `weight_fn` returns a
///   non-finite value, or if fetching outgoing edges fails (reported with
///   `to == 0`).
pub fn critical_path(
    graph: &SqliteGraph,
    weight_fn: &WeightCallback,
) -> Result<CriticalPathResult, CriticalPathError> {
    let topo_order = topological_sort(graph).map_err(|e| match e {
        TopoError::CycleDetected { cycle, .. } => CriticalPathError::NotADag { cycle },
    })?;
    if topo_order.is_empty() {
        // Empty graph: trivially empty result.
        return Ok(CriticalPathResult {
            path: Vec::new(),
            distance: 0.0,
            distances: AHashMap::new(),
            predecessors: AHashMap::new(),
            topological_order: Vec::new(),
        });
    }
    // Seed every node with distance 0 and no predecessor; pre-size the maps
    // since the node count is known.
    let mut distances: AHashMap<i64, f64> = AHashMap::with_capacity(topo_order.len());
    let mut predecessors: AHashMap<i64, Option<i64>> = AHashMap::with_capacity(topo_order.len());
    for &node in &topo_order {
        distances.insert(node, 0.0);
        predecessors.insert(node, None);
    }
    // Loop-invariant: allocate the (empty) edge payload once instead of once
    // per edge.
    let edge_data = serde_json::json!({});
    for &u in &topo_order {
        let dist_u = *distances.get(&u).unwrap_or(&0.0);
        let outgoing = graph
            .fetch_outgoing(u)
            .map_err(|e| CriticalPathError::InvalidWeight {
                from: u,
                // Target unknown for a fetch failure; 0 is a sentinel.
                to: 0,
                reason: format!("failed to fetch outgoing edges: {}", e),
            })?;
        for v in outgoing {
            let weight = weight_fn(u, v, &edge_data);
            if !weight.is_finite() {
                return Err(CriticalPathError::InvalidWeight {
                    from: u,
                    to: v,
                    reason: format!("weight is not finite: {}", weight),
                });
            }
            let new_dist = dist_u + weight;
            // Every node in the topological order was seeded above, so a miss
            // here means the sort and the edge set disagree.
            let dist_v = distances
                .get_mut(&v)
                .expect("neighbor missing from topological order");
            if new_dist > *dist_v {
                *dist_v = new_dist;
                predecessors.insert(v, Some(u));
            }
        }
    }
    // Find the endpoint with the largest accumulated distance.
    let mut max_distance = 0.0;
    let mut end_node = None;
    for (&node, &dist) in &distances {
        if dist > max_distance {
            max_distance = dist;
            end_node = Some(node);
        }
    }
    let end_node = match end_node {
        Some(node) => node,
        None => {
            // No strictly positive distance anywhere (e.g. a single node or
            // zero/negative weights): return a one-node path so callers still
            // get a well-formed result.
            let first = topo_order.first().copied().unwrap_or(0);
            return Ok(CriticalPathResult {
                path: vec![first],
                distance: 0.0,
                distances,
                predecessors,
                topological_order: topo_order,
            });
        }
    };
    // Walk predecessor links back from the end node, then reverse into
    // start-to-end order.
    let mut path = Vec::new();
    let mut current = Some(end_node);
    while let Some(node) = current {
        path.push(node);
        current = *predecessors.get(&node).unwrap_or(&None);
    }
    path.reverse();
    Ok(CriticalPathResult {
        path,
        distance: max_distance,
        distances,
        predecessors,
        topological_order: topo_order,
    })
}
/// Same algorithm as [`critical_path`], but reports progress through the
/// supplied [`ProgressCallback`].
///
/// Reported phases: (1) DAG validation, (2) per-node relaxation — with a
/// nested per-node counter — and (3) path reconstruction; `on_complete` is
/// fired on every successful exit.
///
/// NOTE(review): as in [`critical_path`], the weight callback always
/// receives an empty JSON object; stored edge data is not consulted.
///
/// # Errors
///
/// * [`CriticalPathError::NotADag`] if the graph contains a cycle.
/// * [`CriticalPathError::InvalidWeight`] if `weight_fn` returns a
///   non-finite value, or if fetching outgoing edges fails (reported with
///   `to == 0`).
pub fn critical_path_with_progress<F>(
    graph: &SqliteGraph,
    weight_fn: &WeightCallback,
    progress: &F,
) -> Result<CriticalPathResult, CriticalPathError>
where
    F: ProgressCallback,
{
    progress.on_progress(1, Some(3), "Validating DAG structure");
    let topo_order = topological_sort(graph).map_err(|e| match e {
        TopoError::CycleDetected { cycle, .. } => CriticalPathError::NotADag { cycle },
    })?;
    if topo_order.is_empty() {
        // Empty graph: complete immediately with an empty result.
        progress.on_complete();
        return Ok(CriticalPathResult {
            path: Vec::new(),
            distance: 0.0,
            distances: AHashMap::new(),
            predecessors: AHashMap::new(),
            topological_order: Vec::new(),
        });
    }
    progress.on_progress(2, Some(3), "Computing critical path");
    let total_nodes = topo_order.len();
    // Seed every node with distance 0 and no predecessor; pre-size the maps
    // since the node count is known.
    let mut distances: AHashMap<i64, f64> = AHashMap::with_capacity(total_nodes);
    let mut predecessors: AHashMap<i64, Option<i64>> = AHashMap::with_capacity(total_nodes);
    for &node in &topo_order {
        distances.insert(node, 0.0);
        predecessors.insert(node, None);
    }
    // Loop-invariant: allocate the (empty) edge payload once instead of once
    // per edge.
    let edge_data = serde_json::json!({});
    for (i, &u) in topo_order.iter().enumerate() {
        let dist_u = *distances.get(&u).unwrap_or(&0.0);
        let outgoing = graph
            .fetch_outgoing(u)
            .map_err(|e| CriticalPathError::InvalidWeight {
                from: u,
                // Target unknown for a fetch failure; 0 is a sentinel.
                to: 0,
                reason: format!("failed to fetch outgoing edges: {}", e),
            })?;
        for v in outgoing {
            let weight = weight_fn(u, v, &edge_data);
            if !weight.is_finite() {
                return Err(CriticalPathError::InvalidWeight {
                    from: u,
                    to: v,
                    reason: format!("weight is not finite: {}", weight),
                });
            }
            let new_dist = dist_u + weight;
            // Every node in the topological order was seeded above, so a miss
            // here means the sort and the edge set disagree.
            let dist_v = distances
                .get_mut(&v)
                .expect("neighbor missing from topological order");
            if new_dist > *dist_v {
                *dist_v = new_dist;
                predecessors.insert(v, Some(u));
            }
        }
        progress.on_progress(i + 1, Some(total_nodes), "Processing nodes");
    }
    // Find the endpoint with the largest accumulated distance.
    let mut max_distance = 0.0;
    let mut end_node = None;
    for (&node, &dist) in &distances {
        if dist > max_distance {
            max_distance = dist;
            end_node = Some(node);
        }
    }
    let end_node = match end_node {
        Some(node) => node,
        None => {
            // No strictly positive distance anywhere: return a one-node path
            // so callers still get a well-formed result.
            let first = topo_order.first().copied().unwrap_or(0);
            progress.on_complete();
            return Ok(CriticalPathResult {
                path: vec![first],
                distance: 0.0,
                distances,
                predecessors,
                topological_order: topo_order,
            });
        }
    };
    progress.on_progress(3, Some(3), "Reconstructing path");
    // Walk predecessor links back from the end node, then reverse into
    // start-to-end order.
    let mut path = Vec::new();
    let mut current = Some(end_node);
    while let Some(node) = current {
        path.push(node);
        current = *predecessors.get(&node).unwrap_or(&None);
    }
    path.reverse();
    progress.on_complete();
    Ok(CriticalPathResult {
        path,
        distance: max_distance,
        distances,
        predecessors,
        topological_order: topo_order,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{GraphEdge, GraphEntity};

    // NOTE(review): critical_path feeds an empty JSON object to the weight
    // callback (see `edge_data` in the production code), so the `duration`
    // values stored on edges below are never seen by `duration_weight_fn` —
    // every edge effectively weighs the callback's fallback. Several
    // assertions below (e.g. distance == 2.0, distance == 2997.0) encode
    // exactly this behavior.

    /// Linear chain of 4 tasks: t0 -> t1 -> t2 -> t3, with durations
    /// 5, 3, 2 stored in the edge data.
    fn create_linear_weighted_dag() -> SqliteGraph {
        let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
        for i in 0..4 {
            let entity = GraphEntity {
                id: 0,
                kind: "task".to_string(),
                name: format!("task_{}", i),
                file_path: Some(format!("task_{}.rs", i)),
                data: serde_json::json!({"index": i}),
            };
            graph
                .insert_entity(&entity)
                .expect("Failed to insert entity");
        }
        let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
        let weights = vec![5.0, 3.0, 2.0];
        for (i, &weight) in weights.iter().enumerate() {
            let edge = GraphEdge {
                id: 0,
                from_id: entity_ids[i],
                to_id: entity_ids[i + 1],
                edge_type: "depends".to_string(),
                data: serde_json::json!({"duration": weight}),
            };
            graph.insert_edge(&edge).expect("Failed to insert edge");
        }
        graph
    }

    /// Diamond: A -> B -> D (durations 5, 4) and A -> C -> D (durations 3, 2).
    fn create_diamond_weighted_dag() -> SqliteGraph {
        let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
        for i in 0..4 {
            let entity = GraphEntity {
                id: 0,
                kind: "task".to_string(),
                name: format!("task_{}", i),
                file_path: Some(format!("task_{}.rs", i)),
                data: serde_json::json!({"index": i}),
            };
            graph
                .insert_entity(&entity)
                .expect("Failed to insert entity");
        }
        let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
        // Heavy branch: A -> B -> D.
        let edge1 = GraphEdge {
            id: 0,
            from_id: entity_ids[0],
            to_id: entity_ids[1],
            edge_type: "depends".to_string(),
            data: serde_json::json!({"duration": 5.0}),
        };
        graph.insert_edge(&edge1).expect("Failed to insert edge");
        let edge2 = GraphEdge {
            id: 0,
            from_id: entity_ids[1],
            to_id: entity_ids[3],
            edge_type: "depends".to_string(),
            data: serde_json::json!({"duration": 4.0}),
        };
        graph.insert_edge(&edge2).expect("Failed to insert edge");
        // Light branch: A -> C -> D.
        let edge3 = GraphEdge {
            id: 0,
            from_id: entity_ids[0],
            to_id: entity_ids[2],
            edge_type: "depends".to_string(),
            data: serde_json::json!({"duration": 3.0}),
        };
        graph.insert_edge(&edge3).expect("Failed to insert edge");
        let edge4 = GraphEdge {
            id: 0,
            from_id: entity_ids[2],
            to_id: entity_ids[3],
            edge_type: "depends".to_string(),
            data: serde_json::json!({"duration": 2.0}),
        };
        graph.insert_edge(&edge4).expect("Failed to insert edge");
        graph
    }

    /// 3-node directed cycle: c0 -> c1 -> c2 -> c0.
    fn create_cycle_graph() -> SqliteGraph {
        let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
        for i in 0..3 {
            let entity = GraphEntity {
                id: 0,
                kind: "task".to_string(),
                name: format!("cycle_{}", i),
                file_path: Some(format!("cycle_{}.rs", i)),
                data: serde_json::json!({"index": i}),
            };
            graph
                .insert_entity(&entity)
                .expect("Failed to insert entity");
        }
        let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
        for i in 0..3 {
            let edge = GraphEdge {
                id: 0,
                from_id: entity_ids[i],
                to_id: entity_ids[(i + 1) % 3],
                edge_type: "cycle".to_string(),
                data: serde_json::json!({}),
            };
            graph.insert_edge(&edge).expect("Failed to insert edge");
        }
        graph
    }

    /// Fan-out/fan-in: Start -> {A, B, C} -> End, with varying durations.
    fn create_parallel_dag() -> SqliteGraph {
        let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
        for i in 0..5 {
            let entity = GraphEntity {
                id: 0,
                kind: "task".to_string(),
                name: format!("task_{}", i),
                file_path: Some(format!("task_{}.rs", i)),
                data: serde_json::json!({"index": i}),
            };
            graph
                .insert_entity(&entity)
                .expect("Failed to insert entity");
        }
        let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
        let start_to_a = GraphEdge {
            id: 0,
            from_id: entity_ids[0],
            to_id: entity_ids[1],
            edge_type: "depends".to_string(),
            data: serde_json::json!({"duration": 3.0}),
        };
        graph
            .insert_edge(&start_to_a)
            .expect("Failed to insert edge");
        let start_to_b = GraphEdge {
            id: 0,
            from_id: entity_ids[0],
            to_id: entity_ids[2],
            edge_type: "depends".to_string(),
            data: serde_json::json!({"duration": 5.0}),
        };
        graph
            .insert_edge(&start_to_b)
            .expect("Failed to insert edge");
        let start_to_c = GraphEdge {
            id: 0,
            from_id: entity_ids[0],
            to_id: entity_ids[3],
            edge_type: "depends".to_string(),
            data: serde_json::json!({"duration": 2.0}),
        };
        graph
            .insert_edge(&start_to_c)
            .expect("Failed to insert edge");
        let a_to_end = GraphEdge {
            id: 0,
            from_id: entity_ids[1],
            to_id: entity_ids[4],
            edge_type: "depends".to_string(),
            data: serde_json::json!({"duration": 1.0}),
        };
        graph.insert_edge(&a_to_end).expect("Failed to insert edge");
        let b_to_end = GraphEdge {
            id: 0,
            from_id: entity_ids[2],
            to_id: entity_ids[4],
            edge_type: "depends".to_string(),
            data: serde_json::json!({"duration": 3.0}),
        };
        graph.insert_edge(&b_to_end).expect("Failed to insert edge");
        let c_to_end = GraphEdge {
            id: 0,
            from_id: entity_ids[3],
            to_id: entity_ids[4],
            edge_type: "depends".to_string(),
            data: serde_json::json!({"duration": 2.0}),
        };
        graph.insert_edge(&c_to_end).expect("Failed to insert edge");
        graph
    }

    /// Weight callback reading `duration` from edge data, defaulting to 1.0.
    fn duration_weight_fn(_from: i64, _to: i64, edge_data: &Value) -> f64 {
        edge_data
            .get("duration")
            .and_then(|v| v.as_f64())
            .unwrap_or(1.0)
    }

    // A linear chain's critical path must include every node in order.
    #[test]
    fn test_critical_path_linear_chain() {
        let graph = create_linear_weighted_dag();
        let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
        let result = critical_path(&graph, &duration_weight_fn)
            .expect("Critical path should succeed on DAG");
        assert_eq!(result.path.len(), 4, "Path should have 4 nodes");
        assert_eq!(
            result.path, entity_ids,
            "Path should contain all nodes in order"
        );
        assert!(result.distance > 0.0, "Distance should be positive");
    }

    // Both diamond branches have 2 hops; with the effective per-edge weight
    // of 1.0 (empty edge data reaches the callback) the distance is 2.0.
    #[test]
    fn test_critical_path_diamond_selects_heavier_branch() {
        let graph = create_diamond_weighted_dag();
        let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
        let result = critical_path(&graph, &duration_weight_fn)
            .expect("Critical path should succeed on DAG");
        assert_eq!(result.path.len(), 3, "Path should have 3 nodes");
        assert_eq!(result.path[0], entity_ids[0], "Path should start at A");
        assert_eq!(result.path[2], entity_ids[3], "Path should end at D");
        assert_eq!(
            result.distance, 2.0,
            "Distance should be 2 with default weight"
        );
    }

    // Uses a 999.0 fallback to make the empty-edge-data behavior visible:
    // 3 edges * 999.0 = 2997.0.
    #[test]
    fn test_critical_path_weight_extraction() {
        let graph = create_linear_weighted_dag();
        let custom_weight_fn = |_from: i64, _to: i64, edge_data: &Value| -> f64 {
            edge_data
                .get("duration")
                .and_then(|v| v.as_f64())
                .unwrap_or(999.0)
        };
        let result =
            critical_path(&graph, &custom_weight_fn).expect("Critical path should succeed");
        assert_eq!(
            result.distance, 2997.0,
            "With empty edge_data, should use default weight"
        );
    }

    // default_weight_fn (constant 1.0) over a hand-built linear chain.
    #[test]
    fn test_critical_path_default_weight() {
        let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
        for i in 0..4 {
            let entity = GraphEntity {
                id: 0,
                kind: "task".to_string(),
                name: format!("task_{}", i),
                file_path: Some(format!("task_{}.rs", i)),
                data: serde_json::json!({}),
            };
            graph
                .insert_entity(&entity)
                .expect("Failed to insert entity");
        }
        let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
        for i in 0..entity_ids.len().saturating_sub(1) {
            let edge = GraphEdge {
                id: 0,
                from_id: entity_ids[i],
                to_id: entity_ids[i + 1],
                edge_type: "depends".to_string(),
                data: serde_json::json!({}),
            };
            graph.insert_edge(&edge).expect("Failed to insert edge");
        }
        let result =
            critical_path(&graph, &default_weight_fn).expect("Critical path should succeed");
        assert_eq!(result.path.len(), 4, "Path should have 4 nodes");
        assert!(result.distance > 0.0, "Distance should be positive");
    }

    // Fan-out/fan-in graph: path is Start -> (some branch) -> End.
    #[test]
    fn test_critical_path_parallel_tasks() {
        let graph = create_parallel_dag();
        let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
        let result =
            critical_path(&graph, &duration_weight_fn).expect("Critical path should succeed");
        assert_eq!(result.path.len(), 3, "Path should have 3 nodes");
        assert_eq!(result.path[0], entity_ids[0], "Path should start at Start");
        assert_eq!(result.path[2], entity_ids[4], "Path should end at End");
        assert!(result.distance > 0.0, "Distance should be positive");
    }

    // A cyclic graph must fail with NotADag carrying the cycle nodes.
    #[test]
    fn test_critical_path_cycle_detection() {
        let graph = create_cycle_graph();
        let result = critical_path(&graph, &default_weight_fn);
        assert!(result.is_err(), "Critical path should fail on cyclic graph");
        let err = result.unwrap_err();
        match err {
            CriticalPathError::NotADag { cycle } => {
                assert!(!cycle.is_empty(), "Cycle should not be empty");
                assert!(cycle.len() >= 3, "Cycle should have at least 3 nodes");
            }
            _ => panic!("Expected NotADag error"),
        }
    }

    // Empty graph: succeeds with an empty result, not an error.
    #[test]
    fn test_critical_path_empty_graph() {
        let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
        let result = critical_path(&graph, &default_weight_fn)
            .expect("Critical path should succeed on empty graph");
        assert_eq!(result.path.len(), 0, "Path should be empty");
        assert_eq!(result.distance, 0.0, "Distance should be 0");
        assert!(result.distances.is_empty(), "Distances should be empty");
    }

    // Single node, no edges: exercises the all-zero-distances fallback,
    // which returns a one-node path.
    #[test]
    fn test_critical_path_single_node() {
        let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
        let entity = GraphEntity {
            id: 0,
            kind: "task".to_string(),
            name: "single".to_string(),
            file_path: Some("single.rs".to_string()),
            data: serde_json::json!({}),
        };
        graph
            .insert_entity(&entity)
            .expect("Failed to insert entity");
        let result = critical_path(&graph, &default_weight_fn)
            .expect("Critical path should succeed on single node");
        assert_eq!(result.path.len(), 1, "Path should have 1 node");
        assert_eq!(result.distance, 0.0, "Distance should be 0");
    }

    // bottlenecks() is exactly the set of path nodes; the off-path diamond
    // node is excluded.
    #[test]
    fn test_critical_path_bottlenecks() {
        let graph = create_diamond_weighted_dag();
        let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
        let result =
            critical_path(&graph, &duration_weight_fn).expect("Critical path should succeed");
        let bottlenecks = result.bottlenecks();
        assert_eq!(bottlenecks.len(), 3, "Should have 3 bottlenecks");
        assert!(
            bottlenecks.contains(&entity_ids[0]),
            "A should be a bottleneck"
        );
        assert!(
            bottlenecks.contains(&entity_ids[1]),
            "B should be a bottleneck"
        );
        assert!(
            bottlenecks.contains(&entity_ids[3]),
            "D should be a bottleneck"
        );
        assert!(
            !bottlenecks.contains(&entity_ids[2]),
            "C should NOT be a bottleneck"
        );
    }

    // slack() covers every node and is non-negative (distance is the max of
    // all per-node distances).
    #[test]
    fn test_critical_path_slack() {
        let graph = create_diamond_weighted_dag();
        let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
        let result =
            critical_path(&graph, &duration_weight_fn).expect("Critical path should succeed");
        let slack = result.slack();
        assert!(
            slack.contains_key(&entity_ids[0]),
            "A should have slack entry"
        );
        assert!(
            slack.contains_key(&entity_ids[1]),
            "B should have slack entry"
        );
        assert!(
            slack.contains_key(&entity_ids[2]),
            "C should have slack entry"
        );
        assert!(
            slack.contains_key(&entity_ids[3]),
            "D should have slack entry"
        );
        for (node, s) in &slack {
            assert!(*s >= 0.0, "Node {} should have non-negative slack", node);
        }
    }

    // is_bottleneck() agrees with path membership.
    #[test]
    fn test_critical_path_is_bottleneck() {
        let graph = create_diamond_weighted_dag();
        let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
        let result =
            critical_path(&graph, &duration_weight_fn).expect("Critical path should succeed");
        assert!(
            result.is_bottleneck(entity_ids[0]),
            "A should be a bottleneck"
        );
        assert!(
            result.is_bottleneck(entity_ids[1]),
            "B should be a bottleneck"
        );
        assert!(
            result.is_bottleneck(entity_ids[3]),
            "D should be a bottleneck"
        );
        assert!(
            !result.is_bottleneck(entity_ids[2]),
            "C should NOT be a bottleneck"
        );
    }

    // Progress variant: same result as the plain function, and the callback
    // must have been invoked at least once.
    #[test]
    fn test_critical_path_with_progress() {
        let graph = create_linear_weighted_dag();
        // Counts every callback invocation (progress, complete, and error).
        struct TestProgress {
            call_count: std::sync::atomic::AtomicUsize,
        }
        impl ProgressCallback for TestProgress {
            fn on_progress(&self, _current: usize, _total: Option<usize>, _message: &str) {
                self.call_count
                    .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            }
            fn on_complete(&self) {
                self.call_count
                    .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            }
            fn on_error(&self, _error: &dyn std::error::Error) {
                self.call_count
                    .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            }
        }
        let progress = TestProgress {
            call_count: std::sync::atomic::AtomicUsize::new(0),
        };
        let result = critical_path_with_progress(&graph, &duration_weight_fn, &progress)
            .expect("Critical path with progress should succeed");
        assert_eq!(result.path.len(), 4, "Path should have 4 nodes");
        assert!(
            result.distance > 0.0,
            "Distance should be positive, got {}",
            result.distance
        );
        let call_count = progress
            .call_count
            .load(std::sync::atomic::Ordering::SeqCst);
        assert!(call_count > 0, "Progress should have been called");
    }
}