use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
use crate::error::{CoreError, CoreResult, ErrorContext};
/// Identifier handed out sequentially (starting at 0) by [`DependencyGraph`] for each task.
pub type TaskId = u64;
/// A single task registered in a [`DependencyGraph`].
///
/// `#[non_exhaustive]` prevents literal construction outside this crate,
/// so fields can be added later without a breaking change.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct DepTaskNode {
/// Graph-assigned sequential identifier.
pub id: TaskId,
/// Human-readable task name (uniqueness is not enforced).
pub name: String,
/// Scheduling priority; higher values are emitted first by the Kahn sort.
pub priority: i32,
/// Estimated execution cost, used by critical-path and schedule heuristics.
pub estimated_cost: f64,
/// Free-form annotations; created empty and not read by any algorithm in this module.
pub metadata: HashMap<String, String>,
}
/// Algorithm selected by [`DependencyGraphConfig`] for `topological_sort`.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TopologicalAlgorithm {
/// Kahn's in-degree algorithm; orders ready tasks by priority (desc), then id (asc).
Kahn,
/// Iterative depth-first search over the dependents graph.
DfsBased,
}
/// Tuning knobs for [`DependencyGraph`].
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct DependencyGraphConfig {
/// When true, `add_dependency` rejects edges that would close a cycle.
pub enable_cycle_detection: bool,
/// Recursion depth cap used by `find_cycles` (default 1000).
pub max_depth: usize,
/// Which algorithm `topological_sort` dispatches to.
pub topological_order: TopologicalAlgorithm,
}
impl Default for DependencyGraphConfig {
fn default() -> Self {
Self {
enable_cycle_detection: true,
max_depth: 1000,
topological_order: TopologicalAlgorithm::Kahn,
}
}
}
/// Directed acyclic task graph with priority-aware scheduling helpers.
///
/// Forward edges point from a task to the tasks it depends on; `rev_edges`
/// mirrors them so dependents can be looked up without a scan.
pub struct DependencyGraph {
/// Behavioural configuration (cycle detection, depth cap, sort algorithm).
config: DependencyGraphConfig,
/// Task id → node payload.
nodes: HashMap<TaskId, DepTaskNode>,
/// Task id → ids it depends on (forward edges).
edges: HashMap<TaskId, Vec<TaskId>>,
/// Task id → ids that depend on it (reverse edges).
rev_edges: HashMap<TaskId, Vec<TaskId>>,
/// Next id handed out by `add_task*`.
next_id: TaskId,
}
impl DependencyGraph {
pub fn new(config: DependencyGraphConfig) -> Self {
Self {
config,
nodes: HashMap::new(),
edges: HashMap::new(),
rev_edges: HashMap::new(),
next_id: 0,
}
}
/// Registers a task with the default estimated cost of `1.0` and returns its id.
pub fn add_task(&mut self, name: &str, priority: i32) -> TaskId {
    // Delegate to the general constructor with the unit cost.
    self.add_task_with_cost(name, priority, 1.0)
}
/// Registers a task with an explicit estimated cost and returns its new id.
///
/// Ids are sequential; both adjacency maps are pre-seeded with empty
/// vectors so later edge lookups for this id always hit.
pub fn add_task_with_cost(&mut self, name: &str, priority: i32, cost: f64) -> TaskId {
    let id = self.next_id;
    self.next_id += 1;
    self.nodes.insert(
        id,
        DepTaskNode {
            id,
            name: name.to_owned(),
            priority,
            estimated_cost: cost,
            metadata: HashMap::new(),
        },
    );
    self.edges.insert(id, Vec::new());
    self.rev_edges.insert(id, Vec::new());
    id
}
/// Records that `task` depends on `dep` (i.e. `dep` must run first).
///
/// The edge is deduplicated, so adding it twice is a no-op.
///
/// # Errors
/// Returns `CoreError::InvalidInput` when either id is unknown, when
/// `task == dep`, or (with cycle detection enabled) when the new edge
/// would close a cycle.
pub fn add_dependency(&mut self, task: TaskId, dep: TaskId) -> CoreResult<()> {
    // Both endpoints must already be registered.
    for (id, label) in [(task, "task"), (dep, "dep")] {
        if !self.nodes.contains_key(&id) {
            return Err(CoreError::InvalidInput(ErrorContext::new(format!(
                "add_dependency: {label} {id} not found"
            ))));
        }
    }
    if task == dep {
        return Err(CoreError::InvalidInput(ErrorContext::new(format!(
            "add_dependency: self-loop on task {task}"
        ))));
    }
    // If task is already reachable from dep, the new edge would form a cycle.
    if self.config.enable_cycle_detection && self.is_reachable(dep, task) {
        return Err(CoreError::InvalidInput(ErrorContext::new(format!(
            "add_dependency: cycle detected — dep {dep} is already reachable from task {task}"
        ))));
    }
    let forward = self.edges.entry(task).or_default();
    if !forward.contains(&dep) {
        forward.push(dep);
    }
    let backward = self.rev_edges.entry(dep).or_default();
    if !backward.contains(&task) {
        backward.push(task);
    }
    Ok(())
}
/// Number of registered tasks.
pub fn n_tasks(&self) -> usize {
    self.nodes.len()
}
/// Total number of dependency edges in the graph.
pub fn n_edges(&self) -> usize {
    self.edges.values().fold(0, |total, deps| total + deps.len())
}
/// Looks up a task node by id.
pub fn get_task(&self, id: TaskId) -> Option<&DepTaskNode> {
    self.nodes.get(&id)
}
/// Direct dependencies of `id`; empty slice for unknown ids.
pub fn dependencies(&self, id: TaskId) -> &[TaskId] {
    match self.edges.get(&id) {
        Some(deps) => deps.as_slice(),
        None => &[],
    }
}
/// Tasks that directly depend on `id` (owned copy; empty for unknown ids).
pub fn dependents(&self, id: TaskId) -> Vec<TaskId> {
    match self.rev_edges.get(&id) {
        Some(children) => children.clone(),
        None => Vec::new(),
    }
}
/// True when every direct dependency of `id` is in `completed`.
///
/// Ids without an adjacency entry are treated as ready.
pub fn is_ready(&self, id: TaskId, completed: &HashSet<TaskId>) -> bool {
    self.edges
        .get(&id)
        .map_or(true, |deps| deps.iter().all(|d| completed.contains(d)))
}
/// Topological order of all tasks using the configured algorithm.
///
/// # Errors
/// Returns an error when the graph contains a cycle.
pub fn topological_sort(&self) -> CoreResult<Vec<TaskId>> {
    match self.config.topological_order {
        TopologicalAlgorithm::DfsBased => self.topological_sort_dfs(),
        TopologicalAlgorithm::Kahn => self.topological_sort_kahn(),
    }
}
pub fn topological_sort_kahn(&self) -> CoreResult<Vec<TaskId>> {
let mut in_degree: HashMap<TaskId, usize> = self
.nodes
.keys()
.map(|&id| (id, self.edges[&id].len()))
.collect();
let mut ready: Vec<TaskId> = in_degree
.iter()
.filter_map(|(&id, °)| if deg == 0 { Some(id) } else { None })
.collect();
ready.sort_unstable_by(|&a, &b| {
let pa = self.nodes[&a].priority;
let pb = self.nodes[&b].priority;
pb.cmp(&pa).then(a.cmp(&b))
});
let mut order = Vec::with_capacity(self.nodes.len());
while !ready.is_empty() {
let id = ready.remove(0);
order.push(id);
let new_ready: Vec<TaskId> = if let Some(children) = self.rev_edges.get(&id) {
children
.iter()
.filter_map(|&child| {
let deg = in_degree.entry(child).or_insert(0);
if *deg > 0 {
*deg -= 1;
}
if *deg == 0 {
Some(child)
} else {
None
}
})
.collect()
} else {
Vec::new()
};
for nid in new_ready {
let pos = ready.partition_point(|&x| {
let px = self.nodes[&x].priority;
let pn = self.nodes[&nid].priority;
px > pn || (px == pn && x < nid)
});
ready.insert(pos, nid);
}
}
if order.len() != self.nodes.len() {
return Err(CoreError::InvalidInput(ErrorContext::new(
"topological_sort: cycle detected in graph",
)));
}
Ok(order)
}
/// Iterative DFS-based topological sort.
///
/// Walks the dependents graph (`rev_edges`) from every root in ascending
/// id order, colouring nodes white(0)/grey(1)/black(2). Post-order is
/// collected and reversed so dependencies precede dependents. Meeting a
/// grey successor means a back edge, i.e. a cycle.
///
/// Fix: the previous version cloned the whole successor `Vec` on every
/// pass through the while loop (once per edge of the current frame,
/// O(deg^2) clones per node); we now borrow it as a slice.
///
/// # Errors
/// Returns `CoreError::InvalidInput` when a cycle is detected.
pub fn topological_sort_dfs(&self) -> CoreResult<Vec<TaskId>> {
    let mut color: HashMap<TaskId, u8> = self.nodes.keys().map(|&id| (id, 0u8)).collect();
    let mut result: Vec<TaskId> = Vec::with_capacity(self.nodes.len());
    let mut all_ids: Vec<TaskId> = self.nodes.keys().copied().collect();
    all_ids.sort_unstable();
    // Explicit stack of (node, index of next successor) replaces recursion.
    let mut call_stack: Vec<(TaskId, usize)> = Vec::new();
    for start in all_ids {
        if color[&start] != 0 {
            continue;
        }
        call_stack.push((start, 0));
        color.insert(start, 1);
        while let Some(frame) = call_stack.last_mut() {
            let (node, idx) = *frame;
            // Borrow, don't clone: self is only read here.
            let successors: &[TaskId] = self
                .rev_edges
                .get(&node)
                .map(|v| v.as_slice())
                .unwrap_or(&[]);
            if idx < successors.len() {
                let child = successors[idx];
                frame.1 += 1;
                match color[&child] {
                    1 => {
                        // Grey successor → back edge → cycle.
                        return Err(CoreError::InvalidInput(ErrorContext::new(
                            "topological_sort_dfs: cycle detected",
                        )));
                    }
                    0 => {
                        color.insert(child, 1);
                        call_stack.push((child, 0));
                    }
                    _ => {} // black: already finished, skip
                }
            } else {
                // All successors handled: finish this node.
                call_stack.pop();
                color.insert(node, 2);
                result.push(node);
            }
        }
    }
    result.reverse();
    Ok(result)
}
/// Collects cycles reachable through the forward (dependency) edges.
///
/// Each returned vector lists the task ids on one cycle, starting where
/// the DFS walk re-entered its own stack. Empty for acyclic graphs.
pub fn find_cycles(&self) -> Vec<Vec<TaskId>> {
    let mut color: HashMap<TaskId, u8> = self.nodes.keys().map(|&id| (id, 0u8)).collect();
    let mut cycles = Vec::new();
    let mut stack = Vec::new();
    for &start in self.nodes.keys() {
        if color[&start] == 0 {
            self.dfs_find_cycles(start, &mut color, &mut stack, &mut cycles);
        }
    }
    cycles
}
/// Recursive white/grey/black DFS helper for [`find_cycles`].
///
/// A grey neighbour means the walk ran into its own stack: the slice of
/// the stack from that neighbour onward is one cycle. Recursion is
/// bounded by `config.max_depth`.
fn dfs_find_cycles(
    &self,
    node: TaskId,
    color: &mut HashMap<TaskId, u8>,
    stack: &mut Vec<TaskId>,
    cycles: &mut Vec<Vec<TaskId>>,
) {
    // Depth guard: leave deeper nodes white rather than overflow the stack.
    if stack.len() >= self.config.max_depth {
        return;
    }
    color.insert(node, 1);
    stack.push(node);
    let neighbours: Vec<TaskId> = self.edges.get(&node).cloned().unwrap_or_default();
    for next in neighbours {
        let state = *color.entry(next).or_insert(0);
        if state == 1 {
            // Back edge: everything from `next` to the stack top is a cycle.
            if let Some(begin) = stack.iter().position(|&x| x == next) {
                cycles.push(stack[begin..].to_vec());
            }
        } else if state == 0 {
            self.dfs_find_cycles(next, color, stack, cycles);
        }
    }
    stack.pop();
    color.insert(node, 2);
}
/// Longest path through the DAG by accumulated `estimated_cost`.
///
/// Forward pass in topological order: for each task, record the maximal
/// total cost of any dependency chain ending at it plus its own cost,
/// and remember which dependency achieved that maximum. Then walk the
/// predecessor links back from the most expensive endpoint. Returns an
/// empty vec when the graph is cyclic (topological sort fails) or empty.
///
/// NOTE(review): when several endpoints (or predecessors) tie on total
/// cost, the choice depends on `HashMap` iteration order, so the exact
/// path is not deterministic across runs — confirm callers tolerate
/// this. NaN costs fall back to `Ordering::Equal` in the comparisons.
pub fn critical_path(&self) -> Vec<TaskId> {
let order = match self.topological_sort() {
Ok(o) => o,
// Cyclic graph: no meaningful critical path.
Err(_) => return Vec::new(),
};
// dist[id] = max total cost of a chain ending at id (inclusive of id).
let mut dist: HashMap<TaskId, f64> = HashMap::new();
// prev[id] = dependency on the best chain into id, if any.
let mut prev: HashMap<TaskId, Option<TaskId>> = HashMap::new();
for &id in &order {
// Missing nodes default to unit cost (defensive; nodes normally exist).
let cost = self.nodes.get(&id).map(|n| n.estimated_cost).unwrap_or(1.0);
// Best accumulated cost among this task's dependencies
// (NEG_INFINITY marks "no scored dependency").
let max_pred_dist = self
.edges
.get(&id)
.map(|deps| {
deps.iter()
.filter_map(|d| dist.get(d).copied())
.fold(f64::NEG_INFINITY, f64::max)
})
.unwrap_or(f64::NEG_INFINITY);
// Pick the dependency realising that maximum (last max wins on ties).
let pred = if max_pred_dist.is_finite() {
self.edges.get(&id).and_then(|deps| {
deps.iter()
.max_by(|&&a, &&b| {
dist.get(&a)
.copied()
.unwrap_or(f64::NEG_INFINITY)
.partial_cmp(&dist.get(&b).copied().unwrap_or(f64::NEG_INFINITY))
.unwrap_or(std::cmp::Ordering::Equal)
})
.copied()
})
} else {
None
};
// Source tasks start their own chain at `cost`.
let d = if max_pred_dist.is_finite() {
max_pred_dist + cost
} else {
cost
};
dist.insert(id, d);
prev.insert(id, pred);
}
// Endpoint of the most expensive chain overall.
let sink = dist
.iter()
.max_by(|(_, &da), (_, &db)| da.partial_cmp(&db).unwrap_or(std::cmp::Ordering::Equal))
.map(|(&id, _)| id);
// Walk predecessor links back to a source, then flip into forward order.
let mut path = Vec::new();
let mut current = sink;
while let Some(id) = current {
path.push(id);
current = prev.get(&id).and_then(|opt| *opt);
}
path.reverse();
path
}
/// Groups tasks into execution layers.
///
/// Layer `k` contains tasks whose dependencies all live in layers
/// `0..k`, so the tasks of one layer may run in parallel once earlier
/// layers finish. Each layer is sorted by task id.
///
/// Fixes: repairs text corruption from the previous revision (`&deg`
/// and `&current_layer` had been mangled into `°` / `¤t_layer`, which
/// did not compile) and moves each finished layer into the result via
/// `mem::replace` instead of cloning it.
///
/// # Errors
/// Returns `CoreError::InvalidInput` when a cycle keeps some tasks from
/// ever becoming ready.
pub fn execution_layers(&self) -> CoreResult<Vec<Vec<TaskId>>> {
    // Remaining unprocessed dependency count per task.
    let mut in_deg: HashMap<TaskId, usize> = self
        .nodes
        .keys()
        .map(|&id| (id, self.edges[&id].len()))
        .collect();
    // Layer 0: tasks with no dependencies at all.
    let mut current_layer: Vec<TaskId> = in_deg
        .iter()
        .filter_map(|(&id, &deg)| if deg == 0 { Some(id) } else { None })
        .collect();
    current_layer.sort_unstable();
    let mut layers: Vec<Vec<TaskId>> = Vec::new();
    let mut processed = 0usize;
    while !current_layer.is_empty() {
        processed += current_layer.len();
        let mut next_layer: Vec<TaskId> = Vec::new();
        for id in &current_layer {
            if let Some(children) = self.rev_edges.get(id) {
                for &child in children {
                    let deg = in_deg.entry(child).or_insert(0);
                    if *deg > 0 {
                        *deg -= 1;
                    }
                    if *deg == 0 {
                        next_layer.push(child);
                    }
                }
            }
        }
        next_layer.sort_unstable();
        // Move the finished layer out and continue with the next one.
        layers.push(std::mem::replace(&mut current_layer, next_layer));
    }
    // Tasks never reaching in-degree 0 are stuck on a cycle.
    if processed != self.nodes.len() {
        return Err(CoreError::InvalidInput(ErrorContext::new(
            "execution_layers: cycle detected",
        )));
    }
    Ok(layers)
}
/// Distributes the execution layers over `n_workers` worker lanes.
///
/// Within each layer, tasks are handed out round-robin in descending
/// estimated-cost order; the round-robin cursor carries across layers.
/// `n_workers` is clamped to at least 1.
///
/// # Errors
/// Propagates the cycle error from `execution_layers`.
pub fn parallel_schedule(&self, n_workers: usize) -> CoreResult<Vec<Vec<TaskId>>> {
    let layers = self.execution_layers()?;
    let lanes = n_workers.max(1);
    // Cost lookup with the same unit-cost fallback used elsewhere.
    let cost = |id: TaskId| self.nodes.get(&id).map(|n| n.estimated_cost).unwrap_or(1.0);
    let mut schedule: Vec<Vec<TaskId>> = vec![Vec::new(); lanes];
    let mut slot = 0usize;
    for layer in layers {
        let mut by_cost = layer;
        by_cost.sort_unstable_by(|&a, &b| {
            cost(b).partial_cmp(&cost(a)).unwrap_or(std::cmp::Ordering::Equal)
        });
        for task in by_cost {
            schedule[slot % lanes].push(task);
            slot += 1;
        }
    }
    Ok(schedule)
}
/// BFS over forward (dependency) edges: is `target` reachable from `from`?
///
/// Used by `add_dependency` to veto edges that would close a cycle.
fn is_reachable(&self, from: TaskId, target: TaskId) -> bool {
    if from == target {
        return true;
    }
    let mut seen: HashSet<TaskId> = HashSet::new();
    seen.insert(from);
    let mut frontier: VecDeque<TaskId> = VecDeque::new();
    frontier.push_back(from);
    while let Some(node) = frontier.pop_front() {
        for &dep in self.dependencies(node) {
            if dep == target {
                return true;
            }
            // insert() is false for already-seen nodes: enqueue each once.
            if seen.insert(dep) {
                frontier.push_back(dep);
            }
        }
    }
    false
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Default configuration shared by the tests below.
fn make_config() -> DependencyGraphConfig {
DependencyGraphConfig::default()
}
/// Builds a three-task chain: C depends on B, B depends on A.
fn build_chain() -> (DependencyGraph, TaskId, TaskId, TaskId) {
let mut g = DependencyGraph::new(make_config());
let a = g.add_task("A", 0);
let b = g.add_task("B", 0);
let c = g.add_task("C", 0);
g.add_dependency(b, a).expect("b depends on a");
g.add_dependency(c, b).expect("c depends on b");
(g, a, b, c)
}
/// Adding tasks and one valid edge updates the counters.
#[test]
fn test_add_task_and_dependency_no_error() {
let mut g = DependencyGraph::new(make_config());
let a = g.add_task("a", 0);
let b = g.add_task("b", 0);
g.add_dependency(b, a).expect("valid DAG edge");
assert_eq!(g.n_tasks(), 2);
assert_eq!(g.n_edges(), 1);
}
/// Kahn/DFS order must respect the chain's dependency direction.
#[test]
fn test_topological_sort_chain() {
let (g, a, b, c) = build_chain();
let order = g.topological_sort().expect("acyclic");
assert_eq!(order.len(), 3);
let pos_a = order.iter().position(|&x| x == a).expect("a in order");
let pos_b = order.iter().position(|&x| x == b).expect("b in order");
let pos_c = order.iter().position(|&x| x == c).expect("c in order");
assert!(pos_a < pos_b, "A must precede B");
assert!(pos_b < pos_c, "B must precede C");
}
/// A cycle injected directly into `edges` (bypassing add_dependency's
/// guard) must still be caught by the Kahn sort.
#[test]
fn test_topological_sort_cycle_returns_err() {
let mut g = DependencyGraph::new(DependencyGraphConfig {
enable_cycle_detection: false,
..DependencyGraphConfig::default()
});
let a = g.add_task("a", 0);
let b = g.add_task("b", 0);
// Hand-crafted 2-cycle: a -> b and b -> a.
g.edges.get_mut(&a).expect("a edges").push(b);
g.edges.get_mut(&b).expect("b edges").push(a);
assert!(
g.topological_sort_kahn().is_err(),
"cycle must be detected by Kahn"
);
}
/// With cycle detection on, add_dependency refuses the closing edge.
#[test]
fn test_add_dependency_cycle_rejected() {
let mut g = DependencyGraph::new(make_config());
let a = g.add_task("a", 0);
let b = g.add_task("b", 0);
g.add_dependency(b, a).expect("b → a");
assert!(g.add_dependency(a, b).is_err(), "cycle must be rejected");
}
/// find_cycles reports an injected 2-cycle.
#[test]
fn test_find_cycles_returns_cycle() {
let mut g = DependencyGraph::new(DependencyGraphConfig {
enable_cycle_detection: false,
..DependencyGraphConfig::default()
});
let a = g.add_task("a", 0);
let b = g.add_task("b", 0);
g.edges.get_mut(&a).expect("a").push(b);
g.edges.get_mut(&b).expect("b").push(a);
let cycles = g.find_cycles();
assert!(!cycles.is_empty(), "should find at least one cycle");
}
/// Independent tasks all land in layer 0.
#[test]
fn test_execution_layers_independent_tasks_in_layer_0() {
let mut g = DependencyGraph::new(make_config());
g.add_task("x", 0);
g.add_task("y", 0);
g.add_task("z", 0);
let layers = g.execution_layers().expect("acyclic");
assert_eq!(layers.len(), 1, "all independent tasks in one layer");
assert_eq!(layers[0].len(), 3);
}
/// A three-task chain produces three single-task layers.
#[test]
fn test_execution_layers_chain() {
let (g, _a, _b, _c) = build_chain();
let layers = g.execution_layers().expect("acyclic");
assert_eq!(layers.len(), 3, "chain has 3 layers");
assert_eq!(layers[0].len(), 1); assert_eq!(layers[1].len(), 1); assert_eq!(layers[2].len(), 1); }
/// Critical path follows the expensive branch of a diamond.
#[test]
fn test_critical_path_selects_longest_cost_path() {
let mut g = DependencyGraph::new(make_config());
let source = g.add_task_with_cost("source", 0, 1.0);
let cheap = g.add_task_with_cost("cheap", 0, 1.0);
let expensive = g.add_task_with_cost("expensive", 0, 10.0);
let sink = g.add_task_with_cost("sink", 0, 1.0);
g.add_dependency(cheap, source).expect("cheap dep");
g.add_dependency(expensive, source).expect("expensive dep");
g.add_dependency(sink, cheap).expect("sink dep cheap");
g.add_dependency(sink, expensive)
.expect("sink dep expensive");
let path = g.critical_path();
assert!(!path.is_empty(), "critical path should be non-empty");
assert!(
path.contains(&expensive),
"critical path must go through 'expensive' node"
);
}
/// Every task appears exactly once somewhere in the worker schedule.
#[test]
fn test_parallel_schedule_all_tasks_covered() {
let (g, _a, _b, _c) = build_chain();
let schedule = g.parallel_schedule(2).expect("valid schedule");
let all_tasks: HashSet<TaskId> = schedule.into_iter().flatten().collect();
assert_eq!(all_tasks.len(), 3, "all tasks must be in schedule");
}
/// Default config values are pinned here as a regression guard.
#[test]
fn test_dependency_graph_config_default() {
let cfg = DependencyGraphConfig::default();
assert!(cfg.enable_cycle_detection);
assert_eq!(cfg.max_depth, 1000);
assert_eq!(cfg.topological_order, TopologicalAlgorithm::Kahn);
}
/// is_ready flips only when every dependency is completed.
#[test]
fn test_is_ready_task_with_all_deps_complete() {
let mut g = DependencyGraph::new(make_config());
let a = g.add_task("a", 0);
let b = g.add_task("b", 0);
g.add_dependency(b, a).expect("b dep a");
let completed: HashSet<TaskId> = [a].into();
assert!(g.is_ready(b, &completed), "b is ready when a is complete");
let empty: HashSet<TaskId> = HashSet::new();
assert!(!g.is_ready(b, &empty), "b not ready when a is incomplete");
}
/// The DFS variant must also respect chain ordering.
#[test]
fn test_topological_sort_dfs_chain() {
let (g, a, b, c) = build_chain();
let order = g.topological_sort_dfs().expect("acyclic DFS");
let pos_a = order.iter().position(|&x| x == a).expect("a in order");
let pos_b = order.iter().position(|&x| x == b).expect("b in order");
let pos_c = order.iter().position(|&x| x == c).expect("c in order");
assert!(pos_a < pos_b, "DFS: A must precede B");
assert!(pos_b < pos_c, "DFS: B must precede C");
}
/// Forward and reverse adjacency stay in sync after add_dependency.
#[test]
fn test_dependencies_and_dependents() {
let mut g = DependencyGraph::new(make_config());
let a = g.add_task("a", 0);
let b = g.add_task("b", 0);
g.add_dependency(b, a).expect("b dep a");
assert_eq!(g.dependencies(b), &[a]);
assert_eq!(g.dependents(a), vec![b]);
}
/// Node payload fields round-trip through add_task_with_cost/get_task.
#[test]
fn test_get_task_metadata() {
let mut g = DependencyGraph::new(make_config());
let id = g.add_task_with_cost("my_task", 5, 42.0);
let node = g.get_task(id).expect("task should exist");
assert_eq!(node.name, "my_task");
assert_eq!(node.priority, 5);
assert!((node.estimated_cost - 42.0).abs() < f64::EPSILON);
}
}