use std::collections::{BTreeSet, HashMap, HashSet, VecDeque};
use serde::Serialize;
// Dependency graph keyed by node name.
//
// Each key of `dependencies` is a node; its value lists the nodes that key
// depends on. `add_edge` also registers the dependency itself as a key, so
// every name that appears in a value list is present as a node too.
//
// Despite the name, acyclicity is not enforced on insertion; `toposort`
// is where a cycle gets reported.
//
// NOTE(review): plain `//` comments are used here on purpose. The
// `schemars::JsonSchema` derive copies `///` doc comments into the generated
// schema as descriptions, so doc comments on this type or its field would
// change the emitted schema — confirm before converting these to `///`.
#[derive(Debug, Clone, Default, Serialize, schemars::JsonSchema)]
pub struct Dag {
// node name -> names it depends on, in insertion order (duplicates are kept).
dependencies: HashMap<String, Vec<String>>,
}
impl Dag {
    /// Creates a graph with no nodes and no edges.
    pub fn new() -> Self {
        Default::default()
    }

    /// Registers `key` as a node.
    ///
    /// Calling this for a node that already exists leaves its recorded
    /// dependencies untouched.
    pub fn add_node(&mut self, key: &str) {
        self.dependencies.entry(key.to_owned()).or_default();
    }

    /// Records that `from` depends on `depends_on`.
    ///
    /// Both endpoints become nodes of the graph. Repeated calls with the same
    /// pair append the dependency again; no de-duplication or cycle check is
    /// performed here.
    pub fn add_edge(&mut self, from: &str, depends_on: &str) {
        let target = depends_on.to_owned();
        // Make sure the dependency is itself a node, even if it has no
        // dependencies of its own.
        self.dependencies.entry(target.clone()).or_default();
        let outgoing = self.dependencies.entry(from.to_owned()).or_default();
        outgoing.push(target);
    }

    /// Iterates over every node name, in the map's (unspecified) order.
    pub fn nodes(&self) -> impl Iterator<Item = &String> {
        self.dependencies.keys()
    }

    /// Returns the direct dependencies recorded for `key`, in insertion order.
    ///
    /// An unknown `key` yields an empty slice.
    pub fn dependencies_of(&self, key: &str) -> &[String] {
        match self.dependencies.get(key) {
            Some(list) => list.as_slice(),
            None => &[],
        }
    }

    /// Builds the reverse adjacency map: node -> the nodes that depend on it.
    ///
    /// Every node gets an entry, even when nothing depends on it. A dependent
    /// appears once per edge, so duplicate edges produce duplicate entries.
    /// The order inside each list follows the map's iteration order and is
    /// therefore unspecified.
    pub fn dependents(&self) -> HashMap<String, Vec<String>> {
        // First pass: give every node an (initially empty) list.
        let mut reverse: HashMap<String, Vec<String>> = self
            .dependencies
            .keys()
            .map(|name| (name.clone(), Vec::new()))
            .collect();

        // Second pass: flip each edge.
        for (name, requires) in &self.dependencies {
            for required in requires {
                reverse
                    .entry(required.clone())
                    .or_default()
                    .push(name.clone());
            }
        }
        reverse
    }
}
pub fn toposort(dag: &Dag) -> Result<Vec<String>, Vec<String>> {
let mut remaining: HashMap<String, usize> = HashMap::new();
for node in dag.nodes() {
remaining.insert(node.clone(), dag.dependencies_of(node).len());
}
let dependents = dag.dependents();
let mut ready: Vec<String> = remaining
.iter()
.filter(|(_, &count)| count == 0)
.map(|(k, _)| k.clone())
.collect();
ready.sort();
let mut queue: VecDeque<String> = ready.into_iter().collect();
let mut order: Vec<String> = Vec::with_capacity(remaining.len());
while let Some(node) = queue.pop_front() {
order.push(node.clone());
if let Some(deps) = dependents.get(&node) {
let mut newly_ready: Vec<String> = Vec::new();
for dependent in deps {
if let Some(count) = remaining.get_mut(dependent) {
*count = count.saturating_sub(1);
if *count == 0 {
newly_ready.push(dependent.clone());
}
}
}
newly_ready.sort();
for n in newly_ready {
queue.push_back(n);
}
}
}
if order.len() == remaining.len() {
Ok(order)
} else {
let mut residual: Vec<String> = remaining
.iter()
.filter(|(_, &count)| count > 0)
.map(|(k, _)| k.clone())
.collect();
residual.sort();
Err(residual)
}
}
/// Collects the input cells that `output_cell` transitively depends on.
///
/// Walks the dependency edges of `dag` depth-first starting at `output_cell`.
/// A cell contained in `input_cells` is recorded as a leaf and the walk does
/// not continue past it; any other cell is expanded into its dependencies.
/// Cells are visited at most once, so the walk terminates on cyclic graphs.
///
/// If `output_cell` is itself an input cell, the result is just that cell.
/// The result is returned as a sorted set.
pub fn upstream_input_leaves(
    dag: &Dag,
    output_cell: &str,
    input_cells: &HashSet<String>,
) -> BTreeSet<String> {
    // Names are borrowed from the caller and the graph; only the leaves that
    // end up in the result are copied.
    let mut visited: HashSet<&str> = HashSet::new();
    let mut found: BTreeSet<String> = BTreeSet::new();
    let mut pending: Vec<&str> = vec![output_cell];

    while let Some(current) = pending.pop() {
        let first_visit = visited.insert(current);
        if !first_visit {
            continue;
        }
        if input_cells.contains(current) {
            // Inputs are leaves: record and do not look behind them.
            found.insert(current.to_owned());
        } else {
            pending.extend(dag.dependencies_of(current).iter().map(String::as_str));
        }
    }
    found
}
// Unit and property tests for `Dag`, `toposort` and `upstream_input_leaves`.
#[cfg(test)]
mod tests {
use super::*;
// Serializing a `Dag` exposes the `dependencies` map keyed by node name.
// NOTE(review): despite the name, only serialization is exercised here; the
// derive list on `Dag` has `Serialize` but no `Deserialize`.
#[test]
fn round_trips_through_serde() {
let mut dag = Dag::new();
dag.add_edge("S!C1", "S!A1");
let v = serde_json::to_value(&dag).expect("serialize Dag");
assert_eq!(v["dependencies"]["S!C1"][0], "S!A1");
}
// Dependencies come back in insertion order; an unknown key yields an empty slice.
#[test]
fn dependencies_of_returns_stored_keys() {
let mut dag = Dag::new();
dag.add_edge("S!C1", "S!A1");
dag.add_edge("S!C1", "S!B1");
assert_eq!(
dag.dependencies_of("S!C1"),
&["S!A1".to_string(), "S!B1".to_string()]
);
assert!(dag.dependencies_of("S!Z9").is_empty());
}
// The reverse map has an entry for every node, including ones nothing depends on.
#[test]
fn dependents_yields_the_reverse_map_kahn_needs() {
let mut dag = Dag::new();
dag.add_edge("S!C", "S!A");
dag.add_edge("S!C", "S!B");
let dependents = dag.dependents();
assert_eq!(dependents.get("S!A"), Some(&vec!["S!C".to_string()]));
assert_eq!(dependents.get("S!B"), Some(&vec!["S!C".to_string()]));
assert_eq!(dependents.get("S!C"), Some(&Vec::<String>::new()));
}
// A node added on its own exists and has no dependencies.
#[test]
fn add_node_registers_a_zero_dependency_node() {
let mut dag = Dag::new();
dag.add_node("S!A1");
assert_eq!(dag.nodes().count(), 1);
assert!(dag.dependencies_of("S!A1").is_empty());
}
// In an acyclic graph both dependencies are emitted before their dependent.
#[test]
fn toposort_orders_dependencies_before_dependents() {
let mut dag = Dag::new();
dag.add_node("S!A");
dag.add_node("S!B");
dag.add_edge("S!C", "S!A");
dag.add_edge("S!C", "S!B");
let order = toposort(&dag).expect("acyclic");
let pos = |k: &str| order.iter().position(|n| n == k).unwrap();
assert!(pos("S!A") < pos("S!C"));
assert!(pos("S!B") < pos("S!C"));
}
// A two-node cycle is reported as `Err` carrying the sorted cycle members.
#[test]
fn toposort_returns_residual_on_a_cycle() {
let mut dag = Dag::new();
dag.add_edge("S!A", "S!B");
dag.add_edge("S!B", "S!A");
let residual = toposort(&dag).expect_err("a cycle must be Err");
assert_eq!(residual, vec!["S!A".to_string(), "S!B".to_string()]);
}
// Helper: builds the input-cell set from string literals.
fn inputs(keys: &[&str]) -> HashSet<String> {
keys.iter().map(|k| (*k).to_string()).collect()
}
// Helper: flattens a leaf set into a Vec (already sorted, since the set is a BTreeSet).
fn leaves(set: &BTreeSet<String>) -> Vec<String> {
set.iter().cloned().collect()
}
// Only inputs reachable from the output are returned; `In!withheld` is an
// input cell but is not upstream of `Calc!out`.
#[test]
fn upstream_input_leaves_returns_exactly_reachable_inputs() {
let mut dag = Dag::new();
dag.add_edge("Calc!out", "Calc!f1");
dag.add_edge("Calc!f1", "In!income");
dag.add_edge("Calc!f1", "In!filing");
let input_cells = inputs(&["In!income", "In!filing", "In!withheld"]);
let got = upstream_input_leaves(&dag, "Calc!out", &input_cells);
assert_eq!(leaves(&got), vec!["In!filing", "In!income"]);
}
// The `Const!rate -> Const!base` path contains no input cell, so it
// contributes nothing to the result.
#[test]
fn upstream_input_leaves_excludes_constant_only_path() {
let mut dag = Dag::new();
dag.add_edge("Calc!out", "In!income");
dag.add_edge("Calc!out", "Const!rate"); dag.add_edge("Const!rate", "Const!base"); let input_cells = inputs(&["In!income"]);
let got = upstream_input_leaves(&dag, "Calc!out", &input_cells);
assert_eq!(
leaves(&got),
vec!["In!income"],
"a constant-only upstream path yields no input leaf"
);
}
// An input cell is treated as a leaf even when the graph records
// dependencies behind it (`In!income -> Const!hidden`).
#[test]
fn upstream_input_leaves_input_is_a_leaf_traversal_stops() {
let mut dag = Dag::new();
dag.add_edge("Calc!out", "In!income");
dag.add_edge("In!income", "Const!hidden"); let input_cells = inputs(&["In!income"]);
let got = upstream_input_leaves(&dag, "Calc!out", &input_cells);
assert_eq!(leaves(&got), vec!["In!income"]);
}
// Two outputs share `Calc!shared`; each output's result is computed
// independently and includes the shared intermediate's input.
#[test]
fn upstream_input_leaves_shared_intermediate_unions_per_output() {
let mut dag = Dag::new();
dag.add_edge("Calc!tax", "Calc!shared");
dag.add_edge("Calc!tax", "In!filing");
dag.add_edge("Calc!refund", "Calc!shared");
dag.add_edge("Calc!refund", "In!withheld");
dag.add_edge("Calc!shared", "In!income");
let input_cells = inputs(&["In!income", "In!filing", "In!withheld"]);
let tax = upstream_input_leaves(&dag, "Calc!tax", &input_cells);
assert_eq!(
leaves(&tax),
vec!["In!filing", "In!income"],
"tax = its own upstream leaves (income via shared + filing)"
);
let refund = upstream_input_leaves(&dag, "Calc!refund", &input_cells);
assert_eq!(
leaves(&refund),
vec!["In!income", "In!withheld"],
"refund = its own upstream leaves (income via shared + withheld)"
);
}
// `Calc!a <-> Calc!b` form a cycle; the traversal must still finish and
// find the input hanging off `Calc!a`.
#[test]
fn upstream_input_leaves_terminates_on_a_cycle() {
let mut dag = Dag::new();
dag.add_edge("Calc!out", "Calc!a");
dag.add_edge("Calc!a", "Calc!b");
dag.add_edge("Calc!b", "Calc!a"); dag.add_edge("Calc!a", "In!income");
let input_cells = inputs(&["In!income"]);
let got = upstream_input_leaves(&dag, "Calc!out", &input_cells);
assert_eq!(leaves(&got), vec!["In!income"]);
}
// Property: on random graphs over nodes N0..N11, every returned leaf is an
// input cell and is reachable upstream of the output node N11.
proptest::proptest! {
#[test]
fn prop_upstream_leaves_subset_and_reachable(
edges in proptest::collection::vec(
(0usize..12, 0usize..12),
0..40,
),
input_mask in proptest::collection::vec(proptest::bool::ANY, 12),
) {
let node = |i: usize| format!("N{i}");
let mut dag = Dag::new();
for i in 0..12 {
dag.add_node(&node(i));
}
// Keep only edges that point from a higher index to a lower one, so the
// generated graph cannot contain a cycle.
for (from, dep) in &edges {
if dep < from {
dag.add_edge(&node(*from), &node(*dep));
}
}
let input_cells: HashSet<String> = (0..12)
.filter(|i| input_mask[*i])
.map(node)
.collect();
let output = node(11);
let got = upstream_input_leaves(&dag, &output, &input_cells);
// 1) Every leaf is one of the declared input cells.
for leaf in &got {
proptest::prop_assert!(
input_cells.contains(leaf),
"derived leaf {leaf} must be an input cell"
);
}
// 2) Every leaf is reachable from the output. This reference walk does
// not stop at input cells, so it checks plain graph reachability.
for leaf in &got {
let mut seen: HashSet<String> = HashSet::new();
let mut stack = vec![output.clone()];
let mut reached = false;
while let Some(c) = stack.pop() {
if c == *leaf {
reached = true;
break;
}
if !seen.insert(c.clone()) {
continue;
}
for d in dag.dependencies_of(&c) {
stack.push(d.clone());
}
}
proptest::prop_assert!(
reached,
"derived leaf {leaf} must be reachable upstream of {output}"
);
}
}
}
}