formualizer_eval/engine/topo/pk.rs

1use rustc_hash::{FxHashMap, FxHashSet};
2use std::cmp::Ordering;
3
/// Read-only view of the dependency DAG whose topological order `DynamicTopo` maintains.
///
/// The engine's own storage stays the source of truth; this trait only exposes adjacency.
/// Both adjacency queries receive a caller-owned scratch `Vec` and are expected to push
/// the neighbours of `n` into it, so the caller can reuse the buffer instead of
/// allocating a fresh one per query.
pub trait GraphView<N>
where
    N: Copy + Eq + std::hash::Hash,
{
    /// Pushes the dependents of `n` (every `m` with an edge `n -> m`) into `out`.
    fn successors(&self, n: N, out: &mut Vec<N>);

    /// Pushes the dependencies of `n` (every `m` with an edge `m -> n`) into `out`.
    fn predecessors(&self, n: N, out: &mut Vec<N>);

    /// Reports whether `n` is present in the underlying graph.
    fn exists(&self, n: N) -> bool;
}
13
/// Evidence that inserting an edge would close a cycle.
///
/// `DynamicTopo::try_add_edge(x, y)` fills `path` with the existing route from `y` back
/// to `x`, so together with the rejected edge `x -> y` it forms the cycle. For a
/// self-loop (`x == y`) the path is simply `[x, x]`.
#[derive(Debug, Clone)]
pub struct Cycle<N> {
    /// Nodes along the cycle, in edge order.
    pub path: Vec<N>,
}
18
/// Work counters reported by a successful edge insertion (summed by bulk application).
#[derive(Debug, Clone, Copy, Default)]
pub struct PkStats {
    /// Number of nodes that were moved to a new place in the order.
    pub relabeled: usize,
    /// Number of nodes visited by the bounded forward search.
    pub dfs_visited: usize,
}
24
/// Tuning knobs for `DynamicTopo`.
#[derive(Debug, Clone, Copy)]
pub struct PkConfig {
    /// Maximum number of nodes one incremental search may visit before the
    /// structure falls back to a full rebuild of the order.
    pub visit_budget: usize,
    /// Number of mutating operations between periodic calls to `compact_ranks`.
    pub compaction_interval_ops: u64,
}

impl Default for PkConfig {
    /// A 50 000-node search budget and a compaction every 100 000 operations.
    fn default() -> Self {
        const DEFAULT_VISIT_BUDGET: usize = 50_000;
        const DEFAULT_COMPACTION_INTERVAL: u64 = 100_000;
        PkConfig {
            visit_budget: DEFAULT_VISIT_BUDGET,
            compaction_interval_ops: DEFAULT_COMPACTION_INTERVAL,
        }
    }
}
39
40/// DynamicTopo maintains a deterministic total order (pos) consistent with the conceptual DAG.
41#[derive(Debug)]
42pub struct DynamicTopo<N: Copy + Eq + std::hash::Hash + Ord> {
43    pos: FxHashMap<N, u32>,
44    order: Vec<N>,
45    op_count: u64,
46    cfg: PkConfig,
47    // scratch
48    succ_buf: Vec<N>,
49    pred_buf: Vec<N>,
50}
51
52impl<N> DynamicTopo<N>
53where
54    N: Copy + Eq + std::hash::Hash + Ord,
55{
56    pub fn new(nodes: impl IntoIterator<Item = N>, cfg: PkConfig) -> Self {
57        let mut order: Vec<N> = nodes.into_iter().collect();
58        order.sort(); // stable deterministic seed order
59        let mut pos = FxHashMap::default();
60        for (i, n) in order.iter().enumerate() {
61            pos.insert(*n, i as u32);
62        }
63        Self {
64            pos,
65            order,
66            op_count: 0,
67            cfg,
68            succ_buf: Vec::new(),
69            pred_buf: Vec::new(),
70        }
71    }
72
73    /// Full rebuild via Kahn-style topological sort. Breaks ties by node Ord.
74    pub fn rebuild_full<G: GraphView<N>>(&mut self, graph: &G) {
75        // Collect existing nodes; drop non-existent
76        let mut nodes: Vec<N> = self
77            .order
78            .iter()
79            .copied()
80            .filter(|n| graph.exists(*n))
81            .collect();
82        nodes.sort();
83        let set: FxHashSet<N> = nodes.iter().copied().collect();
84
85        // Compute in-degrees within set
86        let mut indeg: FxHashMap<N, usize> = nodes.iter().map(|&n| (n, 0usize)).collect();
87        for &n in &nodes {
88            self.pred_buf.clear();
89            graph.predecessors(n, &mut self.pred_buf);
90            for &p in &self.pred_buf {
91                if set.contains(&p) {
92                    *indeg.get_mut(&n).unwrap() += 1;
93                }
94            }
95        }
96
97        let mut zero: Vec<N> = indeg
98            .iter()
99            .filter_map(|(&n, &d)| if d == 0 { Some(n) } else { None })
100            .collect();
101        // Sort descending so that pop() returns the smallest first
102        zero.sort_by(|a, b| b.cmp(a));
103
104        let mut out: Vec<N> = Vec::with_capacity(nodes.len());
105        while let Some(n) = zero.pop() {
106            out.push(n);
107            self.succ_buf.clear();
108            graph.successors(n, &mut self.succ_buf);
109            // limited to subset
110            for &s in &self.succ_buf {
111                if let Some(d) = indeg.get_mut(&s) {
112                    *d -= 1;
113                    if *d == 0 {
114                        zero.push(s);
115                    }
116                }
117            }
118            zero.sort_by(|a, b| b.cmp(a)); // maintain pop as smallest
119        }
120        // If out length < nodes, there is a cycle in source graph; keep relative sorted order
121        if out.len() != nodes.len() {
122            out = nodes; // fallback deterministic
123        }
124        self.order = out;
125        self.pos.clear();
126        for (i, n) in self.order.iter().enumerate() {
127            self.pos.insert(*n, i as u32);
128        }
129    }
130
131    pub fn try_add_edge<G: GraphView<N>>(
132        &mut self,
133        graph: &G,
134        x: N,
135        y: N,
136    ) -> Result<PkStats, Cycle<N>> {
137        if x == y {
138            return Err(Cycle { path: vec![x, y] });
139        }
140        let px = match self.pos.get(&x).copied() {
141            Some(v) => v,
142            None => self.add_missing(x),
143        };
144        let py = match self.pos.get(&y).copied() {
145            Some(v) => v,
146            None => self.add_missing(y),
147        };
148        if px < py {
149            return Ok(PkStats::default());
150        }
151
152        // Limited DFS from y through successors with pos <= px
153        let mut stack: Vec<N> = vec![y];
154        let mut parent: FxHashMap<N, N> = FxHashMap::default();
155        let mut visited: FxHashSet<N> = FxHashSet::default();
156        let mut affected: Vec<N> = Vec::new();
157        let mut visited_cnt = 0usize;
158
159        while let Some(u) = stack.pop() {
160            if !visited.insert(u) {
161                continue;
162            }
163            visited_cnt += 1;
164            if visited_cnt > self.cfg.visit_budget {
165                self.rebuild_full(graph);
166                self.op_count += 1;
167                // After rebuild, re-check quickly, recurse once
168                return self.try_add_edge(graph, x, y);
169            }
170            if u == x {
171                // build path from y -> ... -> x
172                let mut path = vec![x];
173                let mut cur = x;
174                while cur != y {
175                    cur = *parent.get(&cur).unwrap();
176                    path.push(cur);
177                }
178                path.reverse();
179                return Err(Cycle { path });
180            }
181            affected.push(u);
182            self.succ_buf.clear();
183            graph.successors(u, &mut self.succ_buf);
184            for &s in &self.succ_buf {
185                if let Some(&ps) = self.pos.get(&s) {
186                    if ps <= px && !visited.contains(&s) {
187                        parent.insert(s, u);
188                        stack.push(s);
189                    }
190                }
191            }
192        }
193
194        // splice affected block to just after px in global order, maintaining relative order
195        let relabeled = self.splice_after(px as usize, &affected);
196
197        self.op_count += 1;
198        if self.op_count % self.cfg.compaction_interval_ops == 0 {
199            self.compact_ranks();
200        }
201        Ok(PkStats {
202            relabeled,
203            dfs_visited: visited_cnt,
204        })
205    }
206
207    pub fn remove_edge(&mut self, _x: N, _y: N) {
208        // PK does not require reorder on deletion.
209        self.op_count += 1;
210        if self.op_count % self.cfg.compaction_interval_ops == 0 {
211            self.compact_ranks();
212        }
213    }
214
215    pub fn apply_bulk<G: GraphView<N>>(
216        &mut self,
217        graph: &G,
218        removes: &[(N, N)],
219        adds: &[(N, N)],
220    ) -> Result<PkStats, Cycle<N>> {
221        for &(x, y) in removes {
222            let _ = (x, y);
223            self.remove_edge(x, y);
224        }
225        let mut stats = PkStats::default();
226        for &(x, y) in adds {
227            match self.try_add_edge(graph, x, y) {
228                Ok(s) => {
229                    stats.relabeled += s.relabeled;
230                    stats.dfs_visited += s.dfs_visited;
231                }
232                Err(c) => {
233                    return Err(c);
234                }
235            }
236        }
237        Ok(stats)
238    }
239
240    #[inline]
241    pub fn topo_order(&self) -> &[N] {
242        &self.order
243    }
244
245    pub fn compact_ranks(&mut self) {
246        // Re-impose deterministic order: stable sort by (current pos, then N)
247        self.order
248            .sort_by(|a, b| match self.pos[a].cmp(&self.pos[b]) {
249                Ordering::Equal => a.cmp(b),
250                o => o,
251            });
252        self.pos.clear();
253        for (i, n) in self.order.iter().enumerate() {
254            self.pos.insert(*n, i as u32);
255        }
256    }
257
258    /// Build parallel-ready layers for a subset, using maintained order for tie-breaks.
259    pub fn layers_for<G: GraphView<N>>(
260        &self,
261        graph: &G,
262        subset: &[N],
263        max_layer_width: Option<usize>,
264    ) -> Vec<Vec<N>> {
265        if subset.is_empty() {
266            return Vec::new();
267        }
268        let subset_set: FxHashSet<N> = subset.iter().copied().collect();
269        let mut indeg: FxHashMap<N, usize> = subset.iter().map(|&n| (n, 0usize)).collect();
270        let mut pred_buf = Vec::new();
271        for &n in subset {
272            pred_buf.clear();
273            // SAFETY: We don't have &mut self graph; create a temp adapter using trait buffers
274            // but GraphView requires &self; we only need predecessors, pass scratch buffer.
275            // Here we can't call self.graph.predecessors directly because we don't have &mut; it's &self fine.
276            // But GraphView signature already takes &self and &mut Vec.
277            graph.predecessors(n, &mut pred_buf);
278            for &p in &pred_buf {
279                if subset_set.contains(&p) {
280                    *indeg.get_mut(&n).unwrap() += 1;
281                }
282            }
283        }
284        let mut zero: Vec<N> = indeg
285            .iter()
286            .filter_map(|(&n, &d)| if d == 0 { Some(n) } else { None })
287            .collect();
288        // Deterministic: by current position, then N
289        zero.sort_by(|a, b| self.pos[a].cmp(&self.pos[b]).then_with(|| a.cmp(b)));
290
291        let mut layers: Vec<Vec<N>> = Vec::new();
292        let mut succ_buf = Vec::new();
293        while !zero.is_empty() {
294            let mut layer = Vec::new();
295            let cap = max_layer_width.unwrap_or(usize::MAX);
296            for _ in 0..zero.len().min(cap) {
297                layer.push(zero.remove(0));
298            }
299            // within layer, sort deterministically
300            layer.sort_by(|a, b| self.pos[a].cmp(&self.pos[b]).then_with(|| a.cmp(b)));
301
302            for &u in &layer {
303                succ_buf.clear();
304                graph.successors(u, &mut succ_buf);
305                for &v in &succ_buf {
306                    if let Some(d) = indeg.get_mut(&v) {
307                        *d -= 1;
308                        if *d == 0 {
309                            zero.push(v);
310                        }
311                    }
312                }
313            }
314            zero.sort_by(|a, b| self.pos[a].cmp(&self.pos[b]).then_with(|| a.cmp(b)));
315            layers.push(layer);
316        }
317        // Any remaining indeg>0 would be a logic bug (cycles should be caught earlier).
318        layers
319    }
320
321    #[inline]
322    fn add_missing(&mut self, n: N) -> u32 {
323        let idx = self.order.len() as u32;
324        self.order.push(n);
325        self.pos.insert(n, idx);
326        idx
327    }
328
329    /// Ensure that all provided nodes exist in the ordering; append any missing at the end.
330    pub fn ensure_nodes(&mut self, nodes: impl IntoIterator<Item = N>) {
331        for n in nodes {
332            if !self.pos.contains_key(&n) {
333                self.add_missing(n);
334            }
335        }
336    }
337
338    /// Move all nodes in `affected` to just after index `after_pos`, preserving their internal order.
339    /// Returns number of nodes relabeled.
340    fn splice_after(&mut self, after_pos: usize, affected: &[N]) -> usize {
341        if affected.is_empty() {
342            return 0;
343        }
344        let mark: FxHashSet<N> = affected.iter().copied().collect();
345        // Extract unaffected and affected in original order
346        let mut left: Vec<N> = Vec::with_capacity(self.order.len() - affected.len());
347        let mut block: Vec<N> = Vec::with_capacity(affected.len());
348        for &n in &self.order {
349            if mark.contains(&n) {
350                block.push(n);
351            } else {
352                left.push(n);
353            }
354        }
355        // Build new order by placing block after after_pos in the left vector
356        let mut new_order: Vec<N> = Vec::with_capacity(self.order.len());
357        let mut i = 0usize;
358        while i < left.len() {
359            new_order.push(left[i]);
360            i += 1;
361            if i - 1 == after_pos {
362                new_order.extend(block.iter().copied());
363            }
364        }
365        if after_pos >= left.len() {
366            new_order.extend(block.iter().copied());
367        }
368        self.order = new_order;
369        // Recompute positions for moved nodes only; but for simplicity recompute all
370        for (i, &n) in self.order.iter().enumerate() {
371            self.pos.insert(n, i as u32);
372        }
373        affected.len()
374    }
375}
376
#[cfg(test)]
mod tests {
    //! Unit tests for `DynamicTopo` against a small in-memory adjacency-list graph.
    use super::*;
    use rustc_hash::FxHashMap;

    /// Minimal `GraphView` backed by forward and reverse adjacency lists.
    #[derive(Default)]
    struct SimpleGraph {
        succ: FxHashMap<u32, Vec<u32>>, // x -> [y]
        pred: FxHashMap<u32, Vec<u32>>, // y -> [x]
    }

    impl SimpleGraph {
        fn add_edge(&mut self, x: u32, y: u32) {
            self.succ.entry(x).or_default().push(y);
            self.pred.entry(y).or_default().push(x);
        }

        fn remove_edge(&mut self, x: u32, y: u32) {
            if let Some(v) = self.succ.get_mut(&x) {
                v.retain(|&t| t != y);
            }
            if let Some(v) = self.pred.get_mut(&y) {
                v.retain(|&s| s != x);
            }
        }
    }

    impl GraphView<u32> for SimpleGraph {
        fn successors(&self, n: u32, out: &mut Vec<u32>) {
            out.clear();
            if let Some(v) = self.succ.get(&n) {
                out.extend(v.iter().copied());
            }
        }
        fn predecessors(&self, n: u32, out: &mut Vec<u32>) {
            out.clear();
            if let Some(v) = self.pred.get(&n) {
                out.extend(v.iter().copied());
            }
        }
        fn exists(&self, _n: u32) -> bool {
            true
        }
    }

    /// Index of `n` in `order`; panics if absent.
    fn idx(order: &[u32], n: u32) -> usize {
        order.iter().position(|&x| x == n).unwrap()
    }

    #[test]
    fn rebuild_full_basic_chain() {
        let mut g = SimpleGraph::default();
        g.add_edge(1, 2);
        g.add_edge(2, 3);
        let nodes = [1, 2, 3, 4];
        let mut pk = DynamicTopo::new(nodes, PkConfig::default());
        pk.rebuild_full(&g);
        let layers = pk.layers_for(&g, &nodes, None);
        // Expect chain 1->2->3 and 4 independent; first layer contains 1 and 4
        assert!(!layers.is_empty());
        assert!(layers[0].contains(&1));
        // 2 must come after 1, 3 after 2 in order
        let order = pk.topo_order().to_vec();
        assert!(idx(&order, 1) < idx(&order, 2));
        assert!(idx(&order, 2) < idx(&order, 3));
    }

    #[test]
    fn add_edge_forward_no_relabel() {
        let g = SimpleGraph::default();
        let nodes = [1, 2, 3, 4];
        let mut pk = DynamicTopo::new(nodes, PkConfig::default());
        pk.rebuild_full(&g);
        let stats = pk.try_add_edge(&g, 1, 3).unwrap();
        assert_eq!(stats.relabeled, 0);
        // Order remains a valid topo order: 1 before 3
        let order = pk.topo_order();
        assert!(idx(order, 1) < idx(order, 3));
    }

    #[test]
    fn add_edge_backedge_splices_without_cycle() {
        // No existing path from 2 to 3, so adding 3->2 should reorder placing 2 after 3
        let g = SimpleGraph::default();
        let nodes = [1, 2, 3, 4];
        let mut pk = DynamicTopo::new(nodes, PkConfig::default());
        pk.rebuild_full(&g);
        let _ = pk.try_add_edge(&g, 3, 2).unwrap();
        let order = pk.topo_order();
        assert!(idx(order, 3) < idx(order, 2));
    }

    #[test]
    fn detect_cycle_on_add() {
        let mut g = SimpleGraph::default();
        g.add_edge(2, 3);
        g.add_edge(3, 4);
        let nodes = [1, 2, 3, 4];
        let mut pk = DynamicTopo::new(nodes, PkConfig::default());
        pk.rebuild_full(&g);
        // Adding 4->2 creates a cycle 2->3->4->2
        let res = pk.try_add_edge(&g, 4, 2);
        assert!(res.is_err());
    }

    #[test]
    fn layers_with_width_cap() {
        let mut g = SimpleGraph::default();
        // 1->3, 2->3, 4 independent
        g.add_edge(1, 3);
        g.add_edge(2, 3);
        let nodes = [1, 2, 3, 4];
        let mut pk = DynamicTopo::new(nodes, PkConfig::default());
        pk.rebuild_full(&g);
        let layers = pk.layers_for(&g, &nodes, Some(1));
        // Each layer must have at most 1 due to cap
        assert!(layers.iter().all(|layer| layer.len() <= 1));
        // Union of all layers equals the subset
        let mut all: Vec<u32> = layers.into_iter().flatten().collect();
        all.sort();
        assert_eq!(all, vec![1, 2, 3, 4]);
    }

    #[test]
    fn layers_unbounded_expected_first_layer() {
        let mut g = SimpleGraph::default();
        g.add_edge(1, 3);
        g.add_edge(2, 3); // zeros: 1,2,4
        let nodes = [1, 2, 3, 4];
        let mut pk = DynamicTopo::new(nodes, PkConfig::default());
        pk.rebuild_full(&g);
        let layers = pk.layers_for(&g, &nodes, None);
        assert!(!layers.is_empty());
        // First layer should contain exactly {1,2,4} in some deterministic order
        let mut got = layers[0].clone();
        got.sort();
        assert_eq!(got, vec![1, 2, 4]);
        assert_eq!(layers[1], vec![3]);
    }

    #[test]
    fn apply_bulk_adds_then_layers() {
        let mut g = SimpleGraph::default();
        let nodes = [1, 2, 3];
        let mut pk = DynamicTopo::new(nodes, PkConfig::default());
        pk.rebuild_full(&g);
        g.add_edge(1, 2);
        g.add_edge(2, 3);
        let _stats = pk.apply_bulk(&g, &[], &[(1, 2), (2, 3)]).unwrap();
        let layers = pk.layers_for(&g, &nodes, None);
        assert_eq!(layers.len(), 3);
    }

    #[test]
    fn budget_config_is_respected_in_api_surface() {
        let g = SimpleGraph::default();
        let mut pk = DynamicTopo::new(
            [1u32, 2, 3, 4, 5],
            PkConfig {
                visit_budget: 1,
                compaction_interval_ops: 10,
            },
        );
        pk.rebuild_full(&g);
        assert_eq!(pk.cfg.visit_budget, 1);
    }

    #[test]
    fn ensure_nodes_appends_missing() {
        let mut pk = DynamicTopo::new([1u32], PkConfig::default());
        pk.ensure_nodes([2u32, 3u32]);
        let order = pk.topo_order();
        assert_eq!(order.len(), 3);
        assert!(idx(order, 1) < idx(order, 2));
        assert!(idx(order, 2) < idx(order, 3));
    }

    #[test]
    fn compact_ranks_keeps_nodes() {
        let g = SimpleGraph::default();
        let nodes = [3, 1, 4, 2];
        let mut pk = DynamicTopo::new(nodes, PkConfig::default());
        pk.rebuild_full(&g);
        pk.compact_ranks();
        let mut order = pk.topo_order().to_vec();
        order.sort();
        assert_eq!(order, vec![1, 2, 3, 4]);
    }

    #[test]
    fn remove_edge_does_not_change_order() {
        let mut g = SimpleGraph::default();
        g.add_edge(1, 3);
        let nodes = [1, 2, 3];
        let mut pk = DynamicTopo::new(nodes, PkConfig::default());
        pk.rebuild_full(&g);
        let before = pk.topo_order().to_vec();
        pk.remove_edge(2, 1); // unrelated removal
        let after = pk.topo_order().to_vec();
        assert_eq!(before, after);
    }

    #[test]
    fn apply_bulk_mixed_removes_and_adds() {
        // Start with a chain 1->2->3; then remove (1,2) and add (1,3)
        let mut g = SimpleGraph::default();
        g.add_edge(1, 2);
        g.add_edge(2, 3);
        let nodes = [1, 2, 3, 4];
        let mut pk = DynamicTopo::new(nodes, PkConfig::default());
        pk.rebuild_full(&g);

        // Prime PK with current edges
        let _ = pk.apply_bulk(&g, &[], &[(1, 2), (2, 3)]).unwrap();

        // Mutate graph: remove (1,2), add (1,3). unwrap() asserts no cycle was reported.
        g.remove_edge(1, 2);
        g.add_edge(1, 3);
        pk.apply_bulk(&g, &[(1, 2)], &[(1, 3)]).unwrap();
        // After changes, at minimum 1 must precede 3; 2 can be anywhere relative to 1
        let order = pk.topo_order().to_vec();
        assert!(idx(&order, 1) < idx(&order, 3));
    }

    #[test]
    fn layers_for_subset_only() {
        // Full graph: 1->2, 2->3, 4->5; subset: [2,3,5]
        let mut g = SimpleGraph::default();
        g.add_edge(1, 2);
        g.add_edge(2, 3);
        g.add_edge(4, 5);
        let nodes = [1, 2, 3, 4, 5];
        let mut pk = DynamicTopo::new(nodes, PkConfig::default());
        pk.rebuild_full(&g);
        // Build layers only for subset; expect 2 before 3; 5 has indegree 1 from 4 (outside subset) so indegree 0 in subset
        let subset = vec![2, 3, 5];
        let layers = pk.layers_for(&g, &subset, None);
        // Flatten and compare membership equals subset
        let mut flat: Vec<u32> = layers.iter().flatten().copied().collect();
        flat.sort();
        assert_eq!(flat, vec![2, 3, 5]);
        // 2 should be in a layer before 3
        let pos2 = layers.iter().position(|lay| lay.contains(&2)).unwrap();
        let pos3 = layers.iter().position(|lay| lay.contains(&3)).unwrap();
        assert!(pos2 < pos3);
        // 5 should be in the first layer (no in-subset predecessors)
        assert!(layers[0].contains(&5));
    }

    #[test]
    fn compact_ranks_repeated_stability() {
        // After establishing some ordering pressure, repeated compactions shouldn't change order
        let mut g = SimpleGraph::default();
        g.add_edge(1, 3);
        g.add_edge(2, 3);
        let nodes = [1, 2, 3, 4];
        let mut pk = DynamicTopo::new(nodes, PkConfig::default());
        pk.rebuild_full(&g);
        let _ = pk.try_add_edge(&g, 4, 2); // introduce a back-edge reorder (4 before 2)
        let baseline = pk.topo_order().to_vec();
        for _ in 0..10 {
            pk.compact_ranks();
            assert_eq!(baseline, pk.topo_order());
        }
    }

    #[test]
    fn compact_ranks_is_stable_repeated() {
        let g = SimpleGraph::default();
        let nodes = [1, 2, 3, 4];
        let mut pk = DynamicTopo::new(nodes, PkConfig::default());
        pk.rebuild_full(&g);
        // Create a back-edge to reorder: 3->2
        let _ = pk.try_add_edge(&g, 3, 2).unwrap();
        let before = pk.topo_order().to_vec();
        for _ in 0..5 {
            pk.compact_ranks();
            assert_eq!(pk.topo_order(), &before);
        }
    }
}