use rustc_hash::FxHashSet;
use smallvec::SmallVec;
use crate::ast::analyzed::{AnalyzedWorkflow, TaskId, TaskTable};
use crate::error::NikaError;
/// Adjacency list of task ids; `SmallVec` keeps up to four ids inline
/// before spilling to the heap.
pub(crate) type DepVec = SmallVec<[TaskId; 4]>;

/// A topological ordering of a workflow's tasks together with each task's depth.
///
/// Produced by `kahn_sort`: `order` lists every task after all of its
/// dependencies, and `depths[i]` is the length of the longest dependency chain
/// ending at the task with index `i` (tasks without dependencies have depth 0).
#[derive(Debug, Clone)]
pub struct TopoOrder {
    order: Box<[TaskId]>,
    depths: Box<[u32]>,
}

impl TopoOrder {
    /// Task ids in topological (dependency-first) order.
    pub fn order(&self) -> &[TaskId] {
        &self.order[..]
    }

    /// Depth of task `id`.
    ///
    /// # Panics
    ///
    /// Panics if `id` does not index into this ordering.
    pub fn depth(&self, id: TaskId) -> u32 {
        let slot = id.index() as usize;
        self.depths[slot]
    }

    /// Greatest depth over all tasks, or `0` when the ordering is empty.
    pub fn max_depth(&self) -> u32 {
        self.depths.iter().copied().fold(0, u32::max)
    }
}
/// Dependency graph of an analyzed workflow, indexed by [`TaskId`].
///
/// Adjacency is stored per task index in both directions, and a topological
/// order (with per-task depths) is computed once at construction time.
#[derive(Debug, Clone)]
pub struct IndexedDag {
    /// `successors[i]`: tasks that list task `i` as a dependency.
    successors: Vec<DepVec>,
    /// `predecessors[i]`: deduplicated dependencies of task `i`.
    predecessors: Vec<DepVec>,
    /// Topological order and depths produced by `kahn_sort`.
    topo: TopoOrder,
    /// Number of tasks (`wf.tasks.len()` at construction).
    num_tasks: usize,
}

impl IndexedDag {
    /// Builds the dependency graph for `wf`.
    ///
    /// Edges come from both `depends_on` and `implicit_deps`. A dependency
    /// that appears more than once (within one list or across both) produces a
    /// single edge, so duplicates cannot inflate in-degrees into a false cycle.
    ///
    /// # Errors
    ///
    /// Returns [`NikaError::CycleDetected`] if the dependencies form a cycle.
    ///
    /// # Panics
    ///
    /// Panics if a task id or dependency id indexes past `wf.tasks.len()`.
    pub fn from_analyzed(wf: &AnalyzedWorkflow) -> Result<Self, NikaError> {
        let n = wf.tasks.len();
        let mut successors: Vec<DepVec> = vec![DepVec::new(); n];
        let mut predecessors: Vec<DepVec> = vec![DepVec::new(); n];
        let mut in_degree = vec![0u32; n];
        // One scratch set, cleared per task, instead of a fresh hash set
        // (and its allocation) for every task.
        let mut seen_deps: FxHashSet<TaskId> = FxHashSet::default();
        for task in &wf.tasks {
            let idx = task.id.index() as usize;
            seen_deps.clear();
            for &dep_id in task.depends_on.iter().chain(task.implicit_deps.iter()) {
                // Skip repeats so each dependency contributes exactly one edge.
                if !seen_deps.insert(dep_id) {
                    continue;
                }
                successors[dep_id.index() as usize].push(task.id);
                predecessors[idx].push(dep_id);
                in_degree[idx] += 1;
            }
        }
        let topo = kahn_sort(&successors, &mut in_degree, n, &wf.task_table)?;
        Ok(Self {
            successors,
            predecessors,
            topo,
            num_tasks: n,
        })
    }

    /// Deduplicated dependencies of `id` (tasks that must finish first).
    ///
    /// # Panics
    ///
    /// Panics if `id` is out of range for this graph.
    pub fn dependencies(&self, id: TaskId) -> &[TaskId] {
        &self.predecessors[id.index() as usize]
    }

    /// Tasks that depend on `id`.
    ///
    /// # Panics
    ///
    /// Panics if `id` is out of range for this graph.
    pub fn successors(&self, id: TaskId) -> &[TaskId] {
        &self.successors[id.index() as usize]
    }

    /// All task ids in topological order (each after its dependencies).
    pub fn topo_order(&self) -> &[TaskId] {
        self.topo.order()
    }

    /// Depth of `id`: length of the longest dependency chain ending at it.
    pub fn depth(&self, id: TaskId) -> u32 {
        self.topo.depth(id)
    }

    /// Greatest depth over all tasks, or `0` for an empty graph.
    pub fn max_depth(&self) -> u32 {
        self.topo.max_depth()
    }

    /// Tasks that no other task depends on, in ascending index order.
    pub fn final_tasks(&self) -> Vec<TaskId> {
        self.successors
            .iter()
            .enumerate()
            .filter(|(_, succs)| succs.is_empty())
            .map(|(i, _)| TaskId::new(i as u32))
            .collect()
    }

    /// Number of tasks in the graph.
    pub fn len(&self) -> usize {
        self.num_tasks
    }

    /// `true` when the graph has no tasks.
    pub fn is_empty(&self) -> bool {
        self.num_tasks == 0
    }

    /// Returns `true` when every dependency of `id` is marked in `done`.
    ///
    /// `done` is indexed by task index.
    ///
    /// # Panics
    ///
    /// Panics if `id` is out of range, or if `done` is too short to cover
    /// every dependency index of `id`.
    pub fn all_deps_done(&self, id: TaskId, done: &[bool]) -> bool {
        self.predecessors[id.index() as usize]
            .iter()
            .all(|dep| done[dep.index() as usize])
    }

    /// Tasks with no dependencies, in ascending index order.
    pub fn root_tasks(&self) -> Vec<TaskId> {
        self.predecessors
            .iter()
            .enumerate()
            .filter(|(_, preds)| preds.is_empty())
            .map(|(i, _)| TaskId::new(i as u32))
            .collect()
    }

    /// The underlying topological order and depth table.
    pub fn topo(&self) -> &TopoOrder {
        &self.topo
    }
}
fn kahn_sort(
successors: &[DepVec],
in_degree: &mut [u32],
n: usize,
task_table: &TaskTable,
) -> Result<TopoOrder, NikaError> {
let mut order = Vec::with_capacity(n);
let mut depths = vec![0u32; n];
let mut queue: std::collections::VecDeque<TaskId> = in_degree
.iter()
.enumerate()
.filter(|(_, °)| deg == 0)
.map(|(i, _)| TaskId::new(i as u32))
.collect();
while let Some(node) = queue.pop_front() {
order.push(node);
let node_depth = depths[node.index() as usize];
for &succ in &successors[node.index() as usize] {
let succ_idx = succ.index() as usize;
depths[succ_idx] = depths[succ_idx].max(node_depth + 1);
in_degree[succ_idx] -= 1;
if in_degree[succ_idx] == 0 {
queue.push_back(succ);
}
}
}
if order.len() != n {
let cycle_tasks: Vec<String> = in_degree
.iter()
.enumerate()
.filter(|(_, °)| deg > 0)
.filter_map(|(i, _)| task_table.get_name(TaskId::new(i as u32)))
.map(|s| s.to_string())
.collect();
return Err(NikaError::CycleDetected {
cycle: cycle_tasks.join(" → "),
});
}
Ok(TopoOrder {
order: order.into_boxed_slice(),
depths: depths.into_boxed_slice(),
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::analyzed::{AnalyzedTask, AnalyzedTaskAction, AnalyzedWorkflow};
    use crate::binding::WithSpec;
    use crate::source::Span;

    /// Builds a task with the given edges and every optional field unset.
    fn make_task(
        id: TaskId,
        name: &str,
        depends_on: Vec<TaskId>,
        implicit_deps: Vec<TaskId>,
    ) -> AnalyzedTask {
        AnalyzedTask {
            id,
            name: name.to_string(),
            description: None,
            action: AnalyzedTaskAction::default(),
            provider: None,
            model: None,
            with_spec: WithSpec::default(),
            depends_on,
            implicit_deps,
            output: None,
            for_each: None,
            retry: None,
            decompose: None,
            concurrency: None,
            fail_fast: None,
            artifact: None,
            log: None,
            structured: None,
            span: Span::dummy(),
        }
    }

    /// Builds a workflow from `(name, explicit deps)` pairs.
    ///
    /// All names are registered before any dependency is resolved, so a spec
    /// may refer to a task listed later (needed for the cycle tests). The
    /// tests below assume ids are assigned in `specs` order.
    fn build_workflow(specs: &[(&str, &[&str])]) -> AnalyzedWorkflow {
        let mut wf = AnalyzedWorkflow::default();
        for (name, _) in specs {
            wf.task_table.insert(name);
        }
        for (name, deps) in specs {
            let id = wf.task_table.get_id(name).unwrap();
            let depends_on: Vec<TaskId> = deps
                .iter()
                .map(|d| wf.task_table.get_id(d).unwrap())
                .collect();
            wf.tasks.push(make_task(id, name, depends_on, Vec::new()));
        }
        wf
    }

    /// Returns the reported cycle, failing unless `from_analyzed` rejects `wf`
    /// with `NikaError::CycleDetected`.
    fn expect_cycle(wf: &AnalyzedWorkflow) -> String {
        match IndexedDag::from_analyzed(wf) {
            Err(NikaError::CycleDetected { cycle, .. }) => cycle,
            Err(_) => panic!("expected NikaError::CycleDetected"),
            Ok(_) => panic!("cyclic workflow was accepted"),
        }
    }

    /// Builds a two-task workflow `src`/`sink` where `sink` has the given edges.
    fn src_sink_workflow(
        sink_depends_on: impl Fn(TaskId) -> Vec<TaskId>,
        sink_implicit: impl Fn(TaskId) -> Vec<TaskId>,
    ) -> (AnalyzedWorkflow, TaskId, TaskId) {
        let mut wf = AnalyzedWorkflow::default();
        wf.task_table.insert("src");
        wf.task_table.insert("sink");
        let id_src = wf.task_table.get_id("src").unwrap();
        let id_sink = wf.task_table.get_id("sink").unwrap();
        wf.tasks
            .push(make_task(id_src, "src", Vec::new(), Vec::new()));
        wf.tasks.push(make_task(
            id_sink,
            "sink",
            sink_depends_on(id_src),
            sink_implicit(id_src),
        ));
        (wf, id_src, id_sink)
    }

    #[test]
    fn empty_dag() {
        let wf = build_workflow(&[]);
        let dag = IndexedDag::from_analyzed(&wf).unwrap();
        assert_eq!(dag.len(), 0);
        assert!(dag.is_empty());
        assert!(dag.topo_order().is_empty());
        assert!(dag.final_tasks().is_empty());
        assert!(dag.root_tasks().is_empty());
        assert_eq!(dag.max_depth(), 0);
    }

    #[test]
    fn single_task() {
        let wf = build_workflow(&[("solo", &[])]);
        let dag = IndexedDag::from_analyzed(&wf).unwrap();
        assert_eq!(dag.len(), 1);
        assert!(!dag.is_empty());
        assert_eq!(dag.topo_order().len(), 1);
        assert_eq!(dag.depth(TaskId::new(0)), 0);
        assert_eq!(dag.max_depth(), 0);
        assert_eq!(dag.final_tasks().len(), 1);
        assert_eq!(dag.root_tasks().len(), 1);
        assert!(dag.dependencies(TaskId::new(0)).is_empty());
        assert!(dag.successors(TaskId::new(0)).is_empty());
    }

    #[test]
    fn linear_chain() {
        let wf = build_workflow(&[("a", &[]), ("b", &["a"]), ("c", &["b"])]);
        let dag = IndexedDag::from_analyzed(&wf).unwrap();
        assert_eq!(dag.len(), 3);
        let order = dag.topo_order();
        let pos = |id: u32| order.iter().position(|&x| x == TaskId::new(id)).unwrap();
        assert!(pos(0) < pos(1));
        assert!(pos(1) < pos(2));
        assert_eq!(dag.depth(TaskId::new(0)), 0);
        assert_eq!(dag.depth(TaskId::new(1)), 1);
        assert_eq!(dag.depth(TaskId::new(2)), 2);
        assert_eq!(dag.max_depth(), 2);
        assert!(dag.dependencies(TaskId::new(0)).is_empty());
        assert_eq!(dag.successors(TaskId::new(0)), &[TaskId::new(1)]);
        assert_eq!(dag.dependencies(TaskId::new(1)), &[TaskId::new(0)]);
        assert_eq!(dag.successors(TaskId::new(1)), &[TaskId::new(2)]);
        assert_eq!(dag.dependencies(TaskId::new(2)), &[TaskId::new(1)]);
        assert!(dag.successors(TaskId::new(2)).is_empty());
        assert_eq!(dag.final_tasks(), vec![TaskId::new(2)]);
        assert_eq!(dag.root_tasks(), vec![TaskId::new(0)]);
    }

    #[test]
    fn diamond_dag() {
        let wf = build_workflow(&[("a", &[]), ("b", &["a"]), ("c", &["a"]), ("d", &["b", "c"])]);
        let dag = IndexedDag::from_analyzed(&wf).unwrap();
        assert_eq!(dag.len(), 4);
        assert_eq!(dag.depth(TaskId::new(0)), 0);
        assert_eq!(dag.depth(TaskId::new(1)), 1);
        assert_eq!(dag.depth(TaskId::new(2)), 1);
        assert_eq!(dag.depth(TaskId::new(3)), 2);
        assert_eq!(dag.max_depth(), 2);
        let d_deps = dag.dependencies(TaskId::new(3));
        assert_eq!(d_deps.len(), 2);
        assert!(d_deps.contains(&TaskId::new(1)));
        assert!(d_deps.contains(&TaskId::new(2)));
        assert_eq!(dag.successors(TaskId::new(0)).len(), 2);
        assert_eq!(dag.root_tasks(), vec![TaskId::new(0)]);
        assert_eq!(dag.final_tasks(), vec![TaskId::new(3)]);
        let order = dag.topo_order();
        let pos = |id: u32| order.iter().position(|&x| x == TaskId::new(id)).unwrap();
        assert!(pos(0) < pos(1));
        assert!(pos(0) < pos(2));
        assert!(pos(1) < pos(3));
        assert!(pos(2) < pos(3));
    }

    #[test]
    fn parallel_independent() {
        let wf = build_workflow(&[("x", &[]), ("y", &[]), ("z", &[])]);
        let dag = IndexedDag::from_analyzed(&wf).unwrap();
        assert_eq!(dag.len(), 3);
        assert_eq!(dag.topo_order().len(), 3);
        for i in 0..3 {
            assert_eq!(dag.depth(TaskId::new(i)), 0);
        }
        assert_eq!(dag.max_depth(), 0);
        assert_eq!(dag.root_tasks().len(), 3);
        assert_eq!(dag.final_tasks().len(), 3);
    }

    #[test]
    fn cycle_detected() {
        // a depends on c, b on a, c on b: a → b → c → a.
        let wf = build_workflow(&[("a", &["c"]), ("b", &["a"]), ("c", &["b"])]);
        let cycle = expect_cycle(&wf);
        assert!(cycle.contains('a'), "cycle {:?} should mention a", cycle);
        assert!(cycle.contains('b'), "cycle {:?} should mention b", cycle);
        assert!(cycle.contains('c'), "cycle {:?} should mention c", cycle);
    }

    #[test]
    fn self_dependency_is_cycle() {
        let wf = build_workflow(&[("a", &["a"])]);
        let cycle = expect_cycle(&wf);
        assert!(cycle.contains('a'), "cycle {:?} should mention a", cycle);
    }

    #[test]
    fn all_deps_done_checks() {
        let wf = build_workflow(&[("a", &[]), ("b", &["a"]), ("c", &["a", "b"])]);
        let dag = IndexedDag::from_analyzed(&wf).unwrap();
        let mut done = vec![false; 3];
        assert!(dag.all_deps_done(TaskId::new(0), &done));
        assert!(!dag.all_deps_done(TaskId::new(1), &done));
        done[0] = true;
        assert!(dag.all_deps_done(TaskId::new(1), &done));
        assert!(!dag.all_deps_done(TaskId::new(2), &done));
        done[1] = true;
        assert!(dag.all_deps_done(TaskId::new(2), &done));
    }

    #[test]
    fn implicit_deps_included() {
        let (wf, id_src, id_sink) = src_sink_workflow(|_| Vec::new(), |src| vec![src]);
        let dag = IndexedDag::from_analyzed(&wf).unwrap();
        assert_eq!(dag.dependencies(id_sink), &[id_src]);
        assert_eq!(dag.successors(id_src), &[id_sink]);
        assert_eq!(dag.depth(id_src), 0);
        assert_eq!(dag.depth(id_sink), 1);
    }

    #[test]
    fn wide_fanout() {
        let mut specs: Vec<(&str, &[&str])> = Vec::new();
        let names = ["r0", "r1", "r2", "r3", "r4", "r5"];
        for name in &names {
            specs.push((name, &[]));
        }
        specs.push(("sink", &["r0", "r1", "r2", "r3", "r4", "r5"]));
        let wf = build_workflow(&specs);
        let dag = IndexedDag::from_analyzed(&wf).unwrap();
        assert_eq!(dag.len(), 7);
        assert_eq!(dag.dependencies(TaskId::new(6)).len(), 6);
        assert_eq!(dag.root_tasks().len(), 6);
        assert_eq!(dag.final_tasks(), vec![TaskId::new(6)]);
        assert_eq!(dag.depth(TaskId::new(6)), 1);
    }

    #[test]
    fn depth_complex() {
        let wf = build_workflow(&[
            ("a", &[]),
            ("b", &["a"]),
            ("c", &["a"]),
            ("d", &["b"]),
            ("e", &["c", "b"]),
            ("f", &["e"]),
        ]);
        let dag = IndexedDag::from_analyzed(&wf).unwrap();
        let expected = [0, 1, 1, 2, 2, 3];
        for (i, &want) in expected.iter().enumerate() {
            assert_eq!(dag.depth(TaskId::new(i as u32)), want, "depth of task {}", i);
        }
        assert_eq!(dag.max_depth(), 3);
    }

    #[test]
    fn topo_order_valid() {
        let wf = build_workflow(&[
            ("a", &[]),
            ("b", &["a"]),
            ("c", &["a"]),
            ("d", &["b", "c"]),
            ("e", &["d"]),
        ]);
        let dag = IndexedDag::from_analyzed(&wf).unwrap();
        let order = dag.topo_order();
        assert_eq!(order.len(), 5);
        for (pos, &task_id) in order.iter().enumerate() {
            for &dep in dag.dependencies(task_id) {
                let dep_pos = order.iter().position(|&x| x == dep).unwrap();
                assert!(
                    dep_pos < pos,
                    "dep {:?} should appear before {:?}",
                    dep,
                    task_id
                );
            }
        }
    }

    #[test]
    fn partial_cycle() {
        // b and c depend on each other; a is an acyclic root feeding b.
        let wf = build_workflow(&[("a", &[]), ("b", &["a", "c"]), ("c", &["b"])]);
        let cycle = expect_cycle(&wf);
        assert!(cycle.contains('b'), "cycle {:?} should mention b", cycle);
        assert!(cycle.contains('c'), "cycle {:?} should mention c", cycle);
        assert!(!cycle.contains('a'), "cycle {:?} should not mention a", cycle);
    }

    #[test]
    fn duplicate_dep_within_depends_on_no_false_cycle() {
        let wf = build_workflow(&[("src", &[]), ("sink", &["src", "src"])]);
        let dag = IndexedDag::from_analyzed(&wf)
            .expect("Repeated dep in depends_on should NOT cause false cycle");
        assert_eq!(dag.dependencies(TaskId::new(1)), &[TaskId::new(0)]);
        assert_eq!(dag.successors(TaskId::new(0)), &[TaskId::new(1)]);
    }

    #[test]
    fn duplicate_dep_in_depends_on_and_implicit_deps_no_false_cycle() {
        let (wf, id_src, id_sink) = src_sink_workflow(|src| vec![src], |src| vec![src]);
        let dag = IndexedDag::from_analyzed(&wf)
            .expect("Duplicate dep in depends_on + implicit_deps should NOT cause false cycle");
        assert_eq!(dag.dependencies(id_sink).len(), 1);
        assert_eq!(dag.dependencies(id_sink), &[id_src]);
        assert_eq!(dag.depth(id_src), 0);
        assert_eq!(dag.depth(id_sink), 1);
    }
}