1use std::collections::{BTreeSet, HashMap, HashSet, VecDeque};
19
20use serde::Serialize;
21
/// Dependency graph over string-keyed nodes.
///
/// The graph is stored as a forward adjacency map: each node key maps to the
/// keys that node depends on. The methods on this type that mutate the map
/// (`add_node`, `add_edge`) give every node that appears on either side of an
/// edge its own entry, so nodes without dependencies map to an empty list.
#[derive(Debug, Clone, Default, Serialize, schemars::JsonSchema)]
pub struct Dag {
    /// Node key -> keys it depends on, in the order the edges were added.
    /// Repeated edges are stored as repeated entries.
    dependencies: HashMap<String, Vec<String>>,
}
33
34impl Dag {
35 pub fn new() -> Self {
37 Self::default()
38 }
39
40 pub fn add_node(&mut self, key: &str) {
43 self.dependencies.entry(key.to_string()).or_default();
44 }
45
46 pub fn add_edge(&mut self, from: &str, depends_on: &str) {
50 self.dependencies
51 .entry(from.to_string())
52 .or_default()
53 .push(depends_on.to_string());
54 self.dependencies.entry(depends_on.to_string()).or_default();
56 }
57
58 pub fn nodes(&self) -> impl Iterator<Item = &String> {
60 self.dependencies.keys()
61 }
62
63 pub fn dependencies_of(&self, key: &str) -> &[String] {
66 self.dependencies.get(key).map_or(&[], |v| v.as_slice())
67 }
68
69 pub fn dependents(&self) -> HashMap<String, Vec<String>> {
75 let mut rev: HashMap<String, Vec<String>> = HashMap::new();
76 for (node, deps) in &self.dependencies {
77 rev.entry(node.clone()).or_default();
78 for dep in deps {
79 rev.entry(dep.clone()).or_default().push(node.clone());
80 }
81 }
82 rev
83 }
84}
85
86pub fn toposort(dag: &Dag) -> Result<Vec<String>, Vec<String>> {
92 let mut remaining: HashMap<String, usize> = HashMap::new();
94 for node in dag.nodes() {
95 remaining.insert(node.clone(), dag.dependencies_of(node).len());
96 }
97
98 let dependents = dag.dependents();
99
100 let mut ready: Vec<String> = remaining
103 .iter()
104 .filter(|(_, &count)| count == 0)
105 .map(|(k, _)| k.clone())
106 .collect();
107 ready.sort();
108 let mut queue: VecDeque<String> = ready.into_iter().collect();
109
110 let mut order: Vec<String> = Vec::with_capacity(remaining.len());
111 while let Some(node) = queue.pop_front() {
112 order.push(node.clone());
113 if let Some(deps) = dependents.get(&node) {
115 let mut newly_ready: Vec<String> = Vec::new();
117 for dependent in deps {
118 if let Some(count) = remaining.get_mut(dependent) {
119 *count = count.saturating_sub(1);
120 if *count == 0 {
121 newly_ready.push(dependent.clone());
122 }
123 }
124 }
125 newly_ready.sort();
126 for n in newly_ready {
127 queue.push_back(n);
128 }
129 }
130 }
131
132 if order.len() == remaining.len() {
135 Ok(order)
136 } else {
137 let mut residual: Vec<String> = remaining
139 .iter()
140 .filter(|(_, &count)| count > 0)
141 .map(|(k, _)| k.clone())
142 .collect();
143 residual.sort();
144 Err(residual)
145 }
146}
147
148pub fn upstream_input_leaves(
167 dag: &Dag,
168 output_cell: &str,
169 input_cells: &HashSet<String>,
170) -> BTreeSet<String> {
171 let mut seen: HashSet<String> = HashSet::new();
172 let mut leaves: BTreeSet<String> = BTreeSet::new();
173 let mut stack: Vec<String> = vec![output_cell.to_string()];
174 while let Some(cell) = stack.pop() {
175 if !seen.insert(cell.clone()) {
177 continue;
178 }
179 if input_cells.contains(&cell) {
180 leaves.insert(cell);
183 continue;
184 }
185 for dep in dag.dependencies_of(&cell) {
187 stack.push(dep.clone());
188 }
189 }
190 leaves
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a graph from `(dependent, dependency)` pairs, added in order.
    fn dag_of(edges: &[(&str, &str)]) -> Dag {
        let mut dag = Dag::new();
        for &(from, depends_on) in edges {
            dag.add_edge(from, depends_on);
        }
        dag
    }

    /// Turns string literals into the owned input-cell set the API expects.
    fn inputs(keys: &[&str]) -> HashSet<String> {
        keys.iter().copied().map(String::from).collect()
    }

    /// Flattens a leaf set into a `Vec` (already sorted, since it is a `BTreeSet`).
    fn leaves(set: &BTreeSet<String>) -> Vec<String> {
        let mut out = Vec::with_capacity(set.len());
        out.extend(set.iter().cloned());
        out
    }

    /// Independent reachability oracle for the property test: plain DFS over
    /// dependency edges that does not stop at input cells.
    fn reaches(dag: &Dag, start: &str, target: &str) -> bool {
        let mut visited: HashSet<&str> = HashSet::new();
        let mut stack: Vec<&str> = vec![start];
        while let Some(cell) = stack.pop() {
            if cell == target {
                return true;
            }
            if visited.insert(cell) {
                stack.extend(dag.dependencies_of(cell).iter().map(String::as_str));
            }
        }
        false
    }

    #[test]
    fn round_trips_through_serde() {
        let dag = dag_of(&[("S!C1", "S!A1")]);
        let json = serde_json::to_value(&dag).expect("serialize Dag");
        assert_eq!(json["dependencies"]["S!C1"][0], "S!A1");
    }

    #[test]
    fn dependencies_of_returns_stored_keys() {
        let dag = dag_of(&[("S!C1", "S!A1"), ("S!C1", "S!B1")]);
        let expected = vec!["S!A1".to_string(), "S!B1".to_string()];
        assert_eq!(dag.dependencies_of("S!C1"), expected.as_slice());
        assert!(dag.dependencies_of("S!Z9").is_empty());
    }

    #[test]
    fn dependents_yields_the_reverse_map_kahn_needs() {
        let reverse = dag_of(&[("S!C", "S!A"), ("S!C", "S!B")]).dependents();
        let only_c = vec!["S!C".to_string()];

        assert_eq!(reverse.get("S!A"), Some(&only_c));
        assert_eq!(reverse.get("S!B"), Some(&only_c));
        // Nothing depends on C, but it must still be present with an empty list.
        assert_eq!(reverse.get("S!C"), Some(&Vec::<String>::new()));
    }

    #[test]
    fn add_node_registers_a_zero_dependency_node() {
        let mut dag = Dag::new();
        dag.add_node("S!A1");

        let registered: Vec<&String> = dag.nodes().collect();
        assert_eq!(registered.len(), 1);
        assert!(dag.dependencies_of("S!A1").is_empty());
    }

    #[test]
    fn toposort_orders_dependencies_before_dependents() {
        let mut dag = Dag::new();
        for leaf in ["S!A", "S!B"] {
            dag.add_node(leaf);
        }
        for leaf in ["S!A", "S!B"] {
            dag.add_edge("S!C", leaf);
        }

        let order = toposort(&dag).expect("acyclic");
        let index_of = |key: &str| order.iter().position(|n| n == key).unwrap();
        let c_at = index_of("S!C");
        assert!(index_of("S!A") < c_at);
        assert!(index_of("S!B") < c_at);
    }

    #[test]
    fn toposort_returns_residual_on_a_cycle() {
        let dag = dag_of(&[("S!A", "S!B"), ("S!B", "S!A")]);
        let stuck = toposort(&dag).expect_err("a cycle must be Err");
        let expected: Vec<String> = vec!["S!A".to_string(), "S!B".to_string()];
        assert_eq!(stuck, expected);
    }

    #[test]
    fn upstream_input_leaves_returns_exactly_reachable_inputs() {
        let dag = dag_of(&[
            ("Calc!out", "Calc!f1"),
            ("Calc!f1", "In!income"),
            ("Calc!f1", "In!filing"),
        ]);
        // `In!withheld` is a declared input that the output never reads.
        let input_cells = inputs(&["In!income", "In!filing", "In!withheld"]);

        let got = upstream_input_leaves(&dag, "Calc!out", &input_cells);
        assert_eq!(leaves(&got), vec!["In!filing", "In!income"]);
    }

    #[test]
    fn upstream_input_leaves_excludes_constant_only_path() {
        let dag = dag_of(&[
            ("Calc!out", "In!income"),
            ("Calc!out", "Const!rate"),
            ("Const!rate", "Const!base"),
        ]);
        let input_cells = inputs(&["In!income"]);

        let got = upstream_input_leaves(&dag, "Calc!out", &input_cells);
        assert_eq!(
            leaves(&got),
            vec!["In!income"],
            "a constant-only upstream path yields no input leaf"
        );
    }

    #[test]
    fn upstream_input_leaves_input_is_a_leaf_traversal_stops() {
        // The input cell has a dependency of its own; it must not be followed.
        let dag = dag_of(&[("Calc!out", "In!income"), ("In!income", "Const!hidden")]);
        let input_cells = inputs(&["In!income"]);

        let got = upstream_input_leaves(&dag, "Calc!out", &input_cells);
        assert_eq!(leaves(&got), vec!["In!income"]);
    }

    #[test]
    fn upstream_input_leaves_shared_intermediate_unions_per_output() {
        let dag = dag_of(&[
            ("Calc!tax", "Calc!shared"),
            ("Calc!tax", "In!filing"),
            ("Calc!refund", "Calc!shared"),
            ("Calc!refund", "In!withheld"),
            ("Calc!shared", "In!income"),
        ]);
        let input_cells = inputs(&["In!income", "In!filing", "In!withheld"]);

        let tax = upstream_input_leaves(&dag, "Calc!tax", &input_cells);
        assert_eq!(
            leaves(&tax),
            vec!["In!filing", "In!income"],
            "tax = its own upstream leaves (income via shared + filing)"
        );

        let refund = upstream_input_leaves(&dag, "Calc!refund", &input_cells);
        assert_eq!(
            leaves(&refund),
            vec!["In!income", "In!withheld"],
            "refund = its own upstream leaves (income via shared + withheld)"
        );
    }

    #[test]
    fn upstream_input_leaves_terminates_on_a_cycle() {
        // a <-> b form a cycle; the walk must still finish and find the input.
        let dag = dag_of(&[
            ("Calc!out", "Calc!a"),
            ("Calc!a", "Calc!b"),
            ("Calc!b", "Calc!a"),
            ("Calc!a", "In!income"),
        ]);
        let input_cells = inputs(&["In!income"]);

        let got = upstream_input_leaves(&dag, "Calc!out", &input_cells);
        assert_eq!(leaves(&got), vec!["In!income"]);
    }

    proptest::proptest! {
        #[test]
        fn prop_upstream_leaves_subset_and_reachable(
            edges in proptest::collection::vec(
                (0usize..12, 0usize..12),
                0..40,
            ),
            input_mask in proptest::collection::vec(proptest::bool::ANY, 12),
        ) {
            let node = |i: usize| format!("N{i}");

            let mut dag = Dag::new();
            for i in 0..12 {
                dag.add_node(&node(i));
            }
            // Only keep edges that point to a strictly lower index, which
            // keeps the generated graph acyclic.
            for (from, dep) in edges.iter().copied().filter(|(from, dep)| dep < from) {
                dag.add_edge(&node(from), &node(dep));
            }

            let input_cells: HashSet<String> = input_mask
                .iter()
                .enumerate()
                .filter(|&(_, &flagged)| flagged)
                .map(|(i, _)| node(i))
                .collect();
            let output = node(11);
            let got = upstream_input_leaves(&dag, &output, &input_cells);

            // Every derived leaf is one of the declared input cells.
            for leaf in &got {
                proptest::prop_assert!(
                    input_cells.contains(leaf),
                    "derived leaf {leaf} must be an input cell"
                );
            }

            // Every derived leaf is reachable from the output via dependency edges.
            for leaf in &got {
                let reached = reaches(&dag, &output, leaf);
                proptest::prop_assert!(
                    reached,
                    "derived leaf {leaf} must be reachable upstream of {output}"
                );
            }
        }
    }
}