// formualizer_eval/engine/topo/pk.rs

use rustc_hash::{FxHashMap, FxHashSet};
use std::cmp::{Ordering, Reverse};
use std::collections::BinaryHeap;

/// Read-only view of the conceptual dependency DAG whose topological order is maintained.
///
/// `successors` reports a node's dependents and `predecessors` its dependencies, both taken
/// from the engine's storage (the source of truth). The `out` vectors are caller-owned
/// scratch buffers: implementors push results into them rather than allocating, so repeated
/// queries can reuse one allocation.
pub trait GraphView<N>
where
    N: Copy + Eq + std::hash::Hash,
{
    /// Pushes the dependents of `n` into `out`.
    fn successors(&self, n: N, out: &mut Vec<N>);

    /// Pushes the dependencies of `n` into `out`.
    fn predecessors(&self, n: N, out: &mut Vec<N>);

    /// Reports whether `n` is still present in the underlying storage.
    fn exists(&self, n: N) -> bool;
}

/// Error payload describing a dependency cycle that an edge insertion would create.
#[derive(Debug, Clone)]
pub struct Cycle<N> {
    /// Nodes along the offending path, in traversal order.
    pub path: Vec<N>,
}

/// Work counters reported by a successful edge insertion (summed across a batch).
#[derive(Debug, Clone, Copy, Default)]
pub struct PkStats {
    /// Number of nodes moved to a new place in the order.
    pub relabeled: usize,
    /// Number of nodes visited by the bounded forward search.
    pub dfs_visited: usize,
}

/// Tuning knobs for the incremental ordering.
#[derive(Debug, Clone, Copy)]
pub struct PkConfig {
    /// Maximum number of nodes an incremental search may visit before the whole
    /// order is rebuilt instead.
    pub visit_budget: usize,
    /// Rank compaction runs whenever the operation counter is a multiple of this value.
    pub compaction_interval_ops: u64,
}

impl Default for PkConfig {
    /// A 50 000-node search budget and compaction every 100 000 operations.
    fn default() -> Self {
        PkConfig {
            compaction_interval_ops: 100_000,
            visit_budget: 50_000,
        }
    }
}

40/// DynamicTopo maintains a deterministic total order (pos) consistent with the conceptual DAG.
41#[derive(Debug)]
42pub struct DynamicTopo<N: Copy + Eq + std::hash::Hash + Ord> {
43    pos: FxHashMap<N, u32>,
44    order: Vec<N>,
45    op_count: u64,
46    cfg: PkConfig,
47    // scratch
48    succ_buf: Vec<N>,
49    pred_buf: Vec<N>,
50}
51
52impl<N> DynamicTopo<N>
53where
54    N: Copy + Eq + std::hash::Hash + Ord,
55{
56    pub fn new(nodes: impl IntoIterator<Item = N>, cfg: PkConfig) -> Self {
57        let mut order: Vec<N> = nodes.into_iter().collect();
58        order.sort(); // stable deterministic seed order
59        let mut pos = FxHashMap::default();
60        for (i, n) in order.iter().enumerate() {
61            pos.insert(*n, i as u32);
62        }
63        Self {
64            pos,
65            order,
66            op_count: 0,
67            cfg,
68            succ_buf: Vec::new(),
69            pred_buf: Vec::new(),
70        }
71    }
72
73    /// Full rebuild via Kahn-style topological sort. Breaks ties by node Ord.
74    pub fn rebuild_full<G: GraphView<N>>(&mut self, graph: &G) {
75        // Collect existing nodes; drop non-existent
76        let mut nodes: Vec<N> = self
77            .order
78            .iter()
79            .copied()
80            .filter(|n| graph.exists(*n))
81            .collect();
82        nodes.sort();
83        let set: FxHashSet<N> = nodes.iter().copied().collect();
84
85        // Compute in-degrees within set
86        let mut indeg: FxHashMap<N, usize> = nodes.iter().map(|&n| (n, 0usize)).collect();
87        for &n in &nodes {
88            self.pred_buf.clear();
89            graph.predecessors(n, &mut self.pred_buf);
90            for &p in &self.pred_buf {
91                if set.contains(&p) {
92                    *indeg.get_mut(&n).unwrap() += 1;
93                }
94            }
95        }
96
97        let mut zero: Vec<N> = indeg
98            .iter()
99            .filter_map(|(&n, &d)| if d == 0 { Some(n) } else { None })
100            .collect();
101        // Sort descending so that pop() returns the smallest first
102        zero.sort_by(|a, b| b.cmp(a));
103
104        let mut out: Vec<N> = Vec::with_capacity(nodes.len());
105        while let Some(n) = zero.pop() {
106            out.push(n);
107            self.succ_buf.clear();
108            graph.successors(n, &mut self.succ_buf);
109            // limited to subset
110            for &s in &self.succ_buf {
111                if let Some(d) = indeg.get_mut(&s) {
112                    *d -= 1;
113                    if *d == 0 {
114                        zero.push(s);
115                    }
116                }
117            }
118            zero.sort_by(|a, b| b.cmp(a)); // maintain pop as smallest
119        }
120        // If out length < nodes, there is a cycle in source graph; keep relative sorted order
121        if out.len() != nodes.len() {
122            out = nodes; // fallback deterministic
123        }
124        self.order = out;
125        self.pos.clear();
126        for (i, n) in self.order.iter().enumerate() {
127            self.pos.insert(*n, i as u32);
128        }
129    }
130
131    pub fn try_add_edge<G: GraphView<N>>(
132        &mut self,
133        graph: &G,
134        x: N,
135        y: N,
136    ) -> Result<PkStats, Cycle<N>> {
137        if x == y {
138            return Err(Cycle { path: vec![x, y] });
139        }
140        let px = match self.pos.get(&x).copied() {
141            Some(v) => v,
142            None => self.add_missing(x),
143        };
144        let py = match self.pos.get(&y).copied() {
145            Some(v) => v,
146            None => self.add_missing(y),
147        };
148        if px < py {
149            return Ok(PkStats::default());
150        }
151
152        // Limited DFS from y through successors with pos <= px
153        let mut stack: Vec<N> = vec![y];
154        let mut parent: FxHashMap<N, N> = FxHashMap::default();
155        let mut visited: FxHashSet<N> = FxHashSet::default();
156        let mut affected: Vec<N> = Vec::new();
157        let mut visited_cnt = 0usize;
158
159        while let Some(u) = stack.pop() {
160            if !visited.insert(u) {
161                continue;
162            }
163            visited_cnt += 1;
164            if visited_cnt > self.cfg.visit_budget {
165                self.rebuild_full(graph);
166                self.op_count += 1;
167                // After rebuild, re-check quickly, recurse once
168                return self.try_add_edge(graph, x, y);
169            }
170            if u == x {
171                // build path from y -> ... -> x
172                let mut path = vec![x];
173                let mut cur = x;
174                while cur != y {
175                    cur = *parent.get(&cur).unwrap();
176                    path.push(cur);
177                }
178                path.reverse();
179                return Err(Cycle { path });
180            }
181            affected.push(u);
182            self.succ_buf.clear();
183            graph.successors(u, &mut self.succ_buf);
184            for &s in &self.succ_buf {
185                if let Some(&ps) = self.pos.get(&s)
186                    && ps <= px
187                    && !visited.contains(&s)
188                {
189                    parent.insert(s, u);
190                    stack.push(s);
191                }
192            }
193        }
194
195        // splice affected block to just after px in global order, maintaining relative order
196        let relabeled = self.splice_after(px as usize, &affected);
197
198        self.op_count += 1;
199        if self
200            .op_count
201            .is_multiple_of(self.cfg.compaction_interval_ops)
202        {
203            self.compact_ranks();
204        }
205        Ok(PkStats {
206            relabeled,
207            dfs_visited: visited_cnt,
208        })
209    }
210
211    pub fn remove_edge(&mut self, _x: N, _y: N) {
212        // PK does not require reorder on deletion.
213        self.op_count += 1;
214        if self
215            .op_count
216            .is_multiple_of(self.cfg.compaction_interval_ops)
217        {
218            self.compact_ranks();
219        }
220    }
221
222    pub fn apply_bulk<G: GraphView<N>>(
223        &mut self,
224        graph: &G,
225        removes: &[(N, N)],
226        adds: &[(N, N)],
227    ) -> Result<PkStats, Cycle<N>> {
228        for &(x, y) in removes {
229            let _ = (x, y);
230            self.remove_edge(x, y);
231        }
232        let mut stats = PkStats::default();
233        for &(x, y) in adds {
234            match self.try_add_edge(graph, x, y) {
235                Ok(s) => {
236                    stats.relabeled += s.relabeled;
237                    stats.dfs_visited += s.dfs_visited;
238                }
239                Err(c) => {
240                    return Err(c);
241                }
242            }
243        }
244        Ok(stats)
245    }
246
247    #[inline]
248    pub fn topo_order(&self) -> &[N] {
249        &self.order
250    }
251
252    pub fn compact_ranks(&mut self) {
253        // Re-impose deterministic order: stable sort by (current pos, then N)
254        self.order
255            .sort_by(|a, b| match self.pos[a].cmp(&self.pos[b]) {
256                Ordering::Equal => a.cmp(b),
257                o => o,
258            });
259        self.pos.clear();
260        for (i, n) in self.order.iter().enumerate() {
261            self.pos.insert(*n, i as u32);
262        }
263    }
264
265    /// Build parallel-ready layers for a subset, using maintained order for tie-breaks.
266    pub fn layers_for<G: GraphView<N>>(
267        &self,
268        graph: &G,
269        subset: &[N],
270        max_layer_width: Option<usize>,
271    ) -> Vec<Vec<N>> {
272        if subset.is_empty() {
273            return Vec::new();
274        }
275        let subset_set: FxHashSet<N> = subset.iter().copied().collect();
276        let mut indeg: FxHashMap<N, usize> = subset.iter().map(|&n| (n, 0usize)).collect();
277        let mut pred_buf = Vec::new();
278        for &n in subset {
279            pred_buf.clear();
280            // SAFETY: We don't have &mut self graph; create a temp adapter using trait buffers
281            // but GraphView requires &self; we only need predecessors, pass scratch buffer.
282            // Here we can't call self.graph.predecessors directly because we don't have &mut; it's &self fine.
283            // But GraphView signature already takes &self and &mut Vec.
284            graph.predecessors(n, &mut pred_buf);
285            for &p in &pred_buf {
286                if subset_set.contains(&p) {
287                    *indeg.get_mut(&n).unwrap() += 1;
288                }
289            }
290        }
291        let mut zero: Vec<N> = indeg
292            .iter()
293            .filter_map(|(&n, &d)| if d == 0 { Some(n) } else { None })
294            .collect();
295        // Deterministic: by current position, then N
296        zero.sort_by(|a, b| self.pos[a].cmp(&self.pos[b]).then_with(|| a.cmp(b)));
297
298        let mut layers: Vec<Vec<N>> = Vec::new();
299        let mut succ_buf = Vec::new();
300        while !zero.is_empty() {
301            let mut layer = Vec::new();
302            let cap = max_layer_width.unwrap_or(usize::MAX);
303            for _ in 0..zero.len().min(cap) {
304                layer.push(zero.remove(0));
305            }
306            // within layer, sort deterministically
307            layer.sort_by(|a, b| self.pos[a].cmp(&self.pos[b]).then_with(|| a.cmp(b)));
308
309            for &u in &layer {
310                succ_buf.clear();
311                graph.successors(u, &mut succ_buf);
312                for &v in &succ_buf {
313                    if let Some(d) = indeg.get_mut(&v) {
314                        *d -= 1;
315                        if *d == 0 {
316                            zero.push(v);
317                        }
318                    }
319                }
320            }
321            zero.sort_by(|a, b| self.pos[a].cmp(&self.pos[b]).then_with(|| a.cmp(b)));
322            layers.push(layer);
323        }
324        // Any remaining indeg>0 would be a logic bug (cycles should be caught earlier).
325        layers
326    }
327
328    #[inline]
329    fn add_missing(&mut self, n: N) -> u32 {
330        let idx = self.order.len() as u32;
331        self.order.push(n);
332        self.pos.insert(n, idx);
333        idx
334    }
335
336    /// Ensure that all provided nodes exist in the ordering; append any missing at the end.
337    pub fn ensure_nodes(&mut self, nodes: impl IntoIterator<Item = N>) {
338        for n in nodes {
339            if !self.pos.contains_key(&n) {
340                self.add_missing(n);
341            }
342        }
343    }
344
345    /// Move all nodes in `affected` to just after index `after_pos`, preserving their internal order.
346    /// Returns number of nodes relabeled.
347    fn splice_after(&mut self, after_pos: usize, affected: &[N]) -> usize {
348        if affected.is_empty() {
349            return 0;
350        }
351        let mark: FxHashSet<N> = affected.iter().copied().collect();
352        // Extract unaffected and affected in original order
353        let mut left: Vec<N> = Vec::with_capacity(self.order.len() - affected.len());
354        let mut block: Vec<N> = Vec::with_capacity(affected.len());
355        for &n in &self.order {
356            if mark.contains(&n) {
357                block.push(n);
358            } else {
359                left.push(n);
360            }
361        }
362        // Build new order by placing block after after_pos in the left vector
363        let mut new_order: Vec<N> = Vec::with_capacity(self.order.len());
364        let mut i = 0usize;
365        while i < left.len() {
366            new_order.push(left[i]);
367            i += 1;
368            if i - 1 == after_pos {
369                new_order.extend(block.iter().copied());
370            }
371        }
372        if after_pos >= left.len() {
373            new_order.extend(block.iter().copied());
374        }
375        self.order = new_order;
376        // Recompute positions for moved nodes only; but for simplicity recompute all
377        for (i, &n) in self.order.iter().enumerate() {
378            self.pos.insert(n, i as u32);
379        }
380        affected.len()
381    }
382}
383
#[cfg(test)]
mod tests {
    use super::*;
    use rustc_hash::FxHashMap;

    /// Adjacency-list graph used to drive `DynamicTopo` in tests; every node "exists".
    #[derive(Default)]
    struct SimpleGraph {
        succ: FxHashMap<u32, Vec<u32>>, // x -> [y]
        pred: FxHashMap<u32, Vec<u32>>, // y -> [x]
    }

    impl SimpleGraph {
        fn add_edge(&mut self, x: u32, y: u32) {
            self.succ.entry(x).or_default().push(y);
            self.pred.entry(y).or_default().push(x);
        }

        fn remove_edge(&mut self, x: u32, y: u32) {
            if let Some(v) = self.succ.get_mut(&x) {
                v.retain(|&t| t != y);
            }
            if let Some(v) = self.pred.get_mut(&y) {
                v.retain(|&s| s != x);
            }
        }
    }

    impl GraphView<u32> for SimpleGraph {
        fn successors(&self, n: u32, out: &mut Vec<u32>) {
            out.clear();
            if let Some(v) = self.succ.get(&n) {
                out.extend(v.iter().copied());
            }
        }
        fn predecessors(&self, n: u32, out: &mut Vec<u32>) {
            out.clear();
            if let Some(v) = self.pred.get(&n) {
                out.extend(v.iter().copied());
            }
        }
        fn exists(&self, _n: u32) -> bool {
            true
        }
    }

    /// Index of `n` in `order`; panics if `n` is absent.
    fn idx(order: &[u32], n: u32) -> usize {
        order.iter().position(|&x| x == n).unwrap()
    }

    #[test]
    fn rebuild_full_basic_chain() {
        let mut g = SimpleGraph::default();
        g.add_edge(1, 2);
        g.add_edge(2, 3);
        let nodes = [1, 2, 3, 4];
        let mut pk = DynamicTopo::new(nodes, PkConfig::default());
        pk.rebuild_full(&g);
        let layers = pk.layers_for(&g, &nodes, None);
        // Expect chain 1->2->3 and 4 independent; first layer contains 1 and 4
        assert!(!layers.is_empty());
        assert!(layers[0].contains(&1));
        // 2 must come after 1, 3 after 2 in order
        let order = pk.topo_order().to_vec();
        assert!(idx(&order, 1) < idx(&order, 2));
        assert!(idx(&order, 2) < idx(&order, 3));
    }

    #[test]
    fn rebuild_full_with_cycle_falls_back_to_sorted() {
        let mut g = SimpleGraph::default();
        g.add_edge(1, 2);
        g.add_edge(2, 1);
        let mut pk = DynamicTopo::new([3u32, 2, 1], PkConfig::default());
        pk.rebuild_full(&g);
        assert_eq!(pk.topo_order(), &[1, 2, 3]);
    }

    #[test]
    fn add_edge_forward_no_relabel() {
        let g = SimpleGraph::default();
        let nodes = [1, 2, 3, 4];
        let mut pk = DynamicTopo::new(nodes, PkConfig::default());
        pk.rebuild_full(&g);
        let stats = pk.try_add_edge(&g, 1, 3).unwrap();
        assert_eq!(stats.relabeled, 0);
        // Order remains a valid topo order: 1 before 3
        let order = pk.topo_order();
        assert!(idx(order, 1) < idx(order, 3));
    }

    #[test]
    fn add_edge_backedge_splices_without_cycle() {
        // No existing path from 2 to 3, so adding 3->2 should reorder placing 2 after 3
        let g = SimpleGraph::default();
        let nodes = [1, 2, 3, 4];
        let mut pk = DynamicTopo::new(nodes, PkConfig::default());
        pk.rebuild_full(&g);
        let _ = pk.try_add_edge(&g, 3, 2).unwrap();
        let order = pk.topo_order();
        assert!(idx(order, 3) < idx(order, 2));
    }

    #[test]
    fn detect_cycle_on_add() {
        let mut g = SimpleGraph::default();
        g.add_edge(2, 3);
        g.add_edge(3, 4);
        let nodes = [1, 2, 3, 4];
        let mut pk = DynamicTopo::new(nodes, PkConfig::default());
        pk.rebuild_full(&g);
        // Adding 4->2 creates a cycle 2->3->4->2
        let res = pk.try_add_edge(&g, 4, 2);
        assert!(res.is_err());
        // The reported path runs from the new edge's head (2) to its tail (4).
        assert_eq!(res.unwrap_err().path, vec![2, 3, 4]);
    }

    #[test]
    fn layers_with_width_cap() {
        let mut g = SimpleGraph::default();
        // 1->3, 2->3, 4 independent
        g.add_edge(1, 3);
        g.add_edge(2, 3);
        let nodes = [1, 2, 3, 4];
        let mut pk = DynamicTopo::new(nodes, PkConfig::default());
        pk.rebuild_full(&g);
        let layers = pk.layers_for(&g, &nodes, Some(1));
        // Each layer must have at most 1 due to cap
        assert!(layers.iter().all(|layer| layer.len() <= 1));
        // Union of all layers equals the subset
        let mut all: Vec<u32> = layers.into_iter().flatten().collect();
        all.sort();
        assert_eq!(all, vec![1, 2, 3, 4]);
    }

    #[test]
    fn layers_unbounded_expected_first_layer() {
        let mut g = SimpleGraph::default();
        g.add_edge(1, 3);
        g.add_edge(2, 3); // zeros: 1,2,4
        let nodes = [1, 2, 3, 4];
        let mut pk = DynamicTopo::new(nodes, PkConfig::default());
        pk.rebuild_full(&g);
        let layers = pk.layers_for(&g, &nodes, None);
        assert!(!layers.is_empty());
        // First layer should contain exactly {1,2,4} in some deterministic order
        let mut got = layers[0].clone();
        got.sort();
        assert_eq!(got, vec![1, 2, 4]);
        assert_eq!(layers[1], vec![3]);
    }

    #[test]
    fn apply_bulk_adds_then_layers() {
        let mut g = SimpleGraph::default();
        let nodes = [1, 2, 3];
        let mut pk = DynamicTopo::new(nodes, PkConfig::default());
        pk.rebuild_full(&g);
        g.add_edge(1, 2);
        g.add_edge(2, 3);
        let _stats = pk.apply_bulk(&g, &[], &[(1, 2), (2, 3)]).unwrap();
        let layers = pk.layers_for(&g, &nodes, None);
        assert_eq!(layers.len(), 3);
    }

    #[test]
    fn budget_config_is_respected_in_api_surface() {
        // Start from order [1,2,3,4,5] with edge 1->2, then add back-edge 5->1. The forward
        // search from 1 needs to visit {1, 2}, which exceeds the budget of 1. That forces
        // the full-rebuild fallback, and the rebuilt order must already satisfy 5 -> 1 -> 2.
        let mut g = SimpleGraph::default();
        g.add_edge(1, 2);
        let mut pk = DynamicTopo::new(
            [1u32, 2, 3, 4, 5],
            PkConfig {
                visit_budget: 1,
                compaction_interval_ops: 10,
            },
        );
        pk.rebuild_full(&g);
        assert_eq!(pk.cfg.visit_budget, 1);

        g.add_edge(5, 1);
        let stats = pk.try_add_edge(&g, 5, 1).unwrap();
        assert_eq!(stats.relabeled, 0);
        let order = pk.topo_order().to_vec();
        assert!(idx(&order, 5) < idx(&order, 1));
        assert!(idx(&order, 1) < idx(&order, 2));
    }

    #[test]
    fn ensure_nodes_appends_missing() {
        let mut pk = DynamicTopo::new([1u32], PkConfig::default());
        pk.ensure_nodes([2u32, 3u32]);
        let order = pk.topo_order();
        assert_eq!(order.len(), 3);
        assert!(idx(order, 1) < idx(order, 2));
        assert!(idx(order, 2) < idx(order, 3));
    }

    #[test]
    fn compact_ranks_keeps_nodes() {
        let g = SimpleGraph::default();
        let nodes = [3, 1, 4, 2];
        let mut pk = DynamicTopo::new(nodes, PkConfig::default());
        pk.rebuild_full(&g);
        pk.compact_ranks();
        let mut order = pk.topo_order().to_vec();
        order.sort();
        assert_eq!(order, vec![1, 2, 3, 4]);
    }

    #[test]
    fn remove_edge_does_not_change_order() {
        let mut g = SimpleGraph::default();
        g.add_edge(1, 3);
        let nodes = [1, 2, 3];
        let mut pk = DynamicTopo::new(nodes, PkConfig::default());
        pk.rebuild_full(&g);
        let before = pk.topo_order().to_vec();
        pk.remove_edge(2, 1); // unrelated removal
        let after = pk.topo_order().to_vec();
        assert_eq!(before, after);
    }

    #[test]
    fn apply_bulk_mixed_removes_and_adds() {
        // Start with a chain 1->2->3; then remove (1,2) and add (1,3)
        let mut g = SimpleGraph::default();
        g.add_edge(1, 2);
        g.add_edge(2, 3);
        let nodes = [1, 2, 3, 4];
        let mut pk = DynamicTopo::new(nodes, PkConfig::default());
        pk.rebuild_full(&g);

        // Prime PK with current edges
        let _ = pk.apply_bulk(&g, &[], &[(1, 2), (2, 3)]).unwrap();

        // Mutate graph: remove (1,2), add (1,3)
        g.remove_edge(1, 2);
        g.add_edge(1, 3);
        let stats = pk.apply_bulk(&g, &[(1, 2)], &[(1, 3)]).unwrap();
        // 1 already precedes 3, so the add moves nothing.
        assert_eq!(stats.relabeled, 0);
        // After changes, at minimum 1 must precede 3; 2 can be anywhere relative to 1
        let order = pk.topo_order().to_vec();
        assert!(idx(&order, 1) < idx(&order, 3));
    }

    #[test]
    fn layers_for_subset_only() {
        // Full graph: 1->2, 2->3, 4->5; subset: [2,3,5]
        let mut g = SimpleGraph::default();
        g.add_edge(1, 2);
        g.add_edge(2, 3);
        g.add_edge(4, 5);
        let nodes = [1, 2, 3, 4, 5];
        let mut pk = DynamicTopo::new(nodes, PkConfig::default());
        pk.rebuild_full(&g);
        // Build layers only for subset. Expect 2 before 3. 5's only predecessor (4) is outside
        // the subset, so 5 has in-degree 0 within it.
        let subset = vec![2, 3, 5];
        let layers = pk.layers_for(&g, &subset, None);
        // Flatten and compare membership equals subset
        let mut flat: Vec<u32> = layers.iter().flatten().copied().collect();
        flat.sort();
        assert_eq!(flat, vec![2, 3, 5]);
        // 2 should be in a layer before 3
        let pos2 = layers.iter().position(|lay| lay.contains(&2)).unwrap();
        let pos3 = layers.iter().position(|lay| lay.contains(&3)).unwrap();
        assert!(pos2 < pos3);
        // 5 should be in the first layer (no in-subset predecessors)
        assert!(layers[0].contains(&5));
    }

    #[test]
    fn compact_ranks_repeated_stability() {
        // After establishing some ordering pressure, repeated compactions shouldn't change order
        let mut g = SimpleGraph::default();
        g.add_edge(1, 3);
        g.add_edge(2, 3);
        let nodes = [1, 2, 3, 4];
        let mut pk = DynamicTopo::new(nodes, PkConfig::default());
        pk.rebuild_full(&g);
        let _ = pk.try_add_edge(&g, 4, 2); // introduce a back-edge reorder (4 before 2)
        let baseline = pk.topo_order().to_vec();
        for _ in 0..10 {
            pk.compact_ranks();
            assert_eq!(baseline, pk.topo_order());
        }
    }

    #[test]
    fn compact_ranks_is_stable_repeated() {
        let g = SimpleGraph::default();
        let nodes = [1, 2, 3, 4];
        let mut pk = DynamicTopo::new(nodes, PkConfig::default());
        pk.rebuild_full(&g);
        // Create a back-edge to reorder: 3->2
        let _ = pk.try_add_edge(&g, 3, 2).unwrap();
        let before = pk.topo_order().to_vec();
        for _ in 0..5 {
            pk.compact_ranks();
            assert_eq!(pk.topo_order(), &before);
        }
    }
}