Skip to main content

asupersync/plan/
extractor.rs

//! Deterministic best-plan extraction with a cost model.
//!
//! Chooses an optimized representative from an e-graph using a deterministic
//! cost model. The extraction algorithm is greedy and produces stable output
//! given the same e-graph structure.
6
7use super::certificate::{CertificateVersion, PlanHash};
8use super::{EClassId, EGraph, ENode, PlanDag, PlanId};
9use std::collections::BTreeMap;
10
11// ===========================================================================
12// Cost model
13// ===========================================================================
14
/// Deterministic cost of a plan node, broken down by resource.
///
/// Every component is a plain counter; for all of them, lower is better.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct PlanCost {
    /// Heap objects the plan is estimated to create.
    pub allocations: u64,
    /// Race nodes whose losers must be drained (cancellation checkpoints).
    pub cancel_checkpoints: u64,
    /// Obligations left pending that must eventually resolve.
    pub obligation_pressure: u64,
    /// Length of the longest sequential chain (Foata depth).
    pub critical_path: u64,
}
29
30impl PlanCost {
31    /// Zero cost.
32    pub const ZERO: Self = Self {
33        allocations: 0,
34        cancel_checkpoints: 0,
35        obligation_pressure: 0,
36        critical_path: 0,
37    };
38
39    /// Sentinel cost for unknown nodes.
40    pub const UNKNOWN: Self = Self {
41        allocations: u64::MAX,
42        cancel_checkpoints: u64::MAX,
43        obligation_pressure: u64::MAX,
44        critical_path: u64::MAX,
45    };
46
47    /// Cost of a leaf node.
48    pub const LEAF: Self = Self {
49        allocations: 1, // One task allocation
50        cancel_checkpoints: 0,
51        obligation_pressure: 0,
52        critical_path: 1,
53    };
54
55    /// Add costs together (for parallel/join composition).
56    #[must_use]
57    #[allow(clippy::should_implement_trait)]
58    pub fn add(self, other: Self) -> Self {
59        Self {
60            allocations: self.allocations.saturating_add(other.allocations),
61            cancel_checkpoints: self
62                .cancel_checkpoints
63                .saturating_add(other.cancel_checkpoints),
64            obligation_pressure: self
65                .obligation_pressure
66                .saturating_add(other.obligation_pressure),
67            critical_path: self.critical_path.max(other.critical_path),
68        }
69    }
70
71    /// Sequential cost (critical path is sum, not max).
72    #[must_use]
73    pub fn sequential(self, other: Self) -> Self {
74        Self {
75            allocations: self.allocations.saturating_add(other.allocations),
76            cancel_checkpoints: self
77                .cancel_checkpoints
78                .saturating_add(other.cancel_checkpoints),
79            obligation_pressure: self
80                .obligation_pressure
81                .saturating_add(other.obligation_pressure),
82            critical_path: self.critical_path.saturating_add(other.critical_path),
83        }
84    }
85
86    /// Total scalar cost for comparison (weighted sum).
87    #[must_use]
88    pub fn total(&self) -> u64 {
89        // Weight critical path heavily, then cancel checkpoints, then allocations
90        self.critical_path
91            .saturating_mul(1000)
92            .saturating_add(self.cancel_checkpoints.saturating_mul(100))
93            .saturating_add(self.obligation_pressure.saturating_mul(10))
94            .saturating_add(self.allocations)
95    }
96}
97
98impl PartialOrd for PlanCost {
99    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
100        Some(self.cmp(other))
101    }
102}
103
104impl Ord for PlanCost {
105    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
106        self.total().cmp(&other.total())
107    }
108}
109
110impl std::fmt::Display for PlanCost {
111    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112        write!(
113            f,
114            "alloc={} cancel={} obl={} depth={}",
115            self.allocations, self.cancel_checkpoints, self.obligation_pressure, self.critical_path
116        )
117    }
118}
119
120// ===========================================================================
121// Extractor
122// ===========================================================================
123
124/// Extracts the best plan from an e-graph class.
125#[derive(Debug)]
126pub struct Extractor<'a> {
127    egraph: &'a mut EGraph,
128    /// Best cost for each class (memoized).
129    costs: BTreeMap<EClassId, PlanCost>,
130    /// Best e-node for each class.
131    best_node: BTreeMap<EClassId, ENode>,
132}
133
134impl<'a> Extractor<'a> {
135    /// Creates a new extractor for the given e-graph.
136    pub fn new(egraph: &'a mut EGraph) -> Self {
137        Self {
138            egraph,
139            costs: BTreeMap::new(),
140            best_node: BTreeMap::new(),
141        }
142    }
143
144    /// Extracts the best plan for a class and returns it as a `PlanDag`.
145    ///
146    /// The extraction is deterministic: given the same e-graph structure,
147    /// it always produces the same `PlanDag`.
148    pub fn extract(&mut self, root: EClassId) -> (PlanDag, ExtractionCertificate) {
149        // Compute costs for all reachable classes
150        self.compute_cost(root);
151
152        // Build the plan DAG from the best nodes
153        let mut dag = PlanDag::new();
154        let mut id_map: BTreeMap<EClassId, PlanId> = BTreeMap::new();
155
156        let dag_root = self.build_plan_node(root, &mut dag, &mut id_map);
157        dag.set_root(dag_root);
158
159        let cost = self
160            .costs
161            .get(&self.egraph.canonical_id(root))
162            .copied()
163            .unwrap_or(PlanCost::ZERO);
164
165        let cert = ExtractionCertificate {
166            version: CertificateVersion::CURRENT,
167            root_class: root,
168            cost,
169            plan_hash: PlanHash::of(&dag),
170            node_count: dag.nodes.len(),
171        };
172
173        (dag, cert)
174    }
175
176    /// Computes the best cost for a class (memoized, bottom-up).
177    fn compute_cost(&mut self, id: EClassId) -> PlanCost {
178        let canonical = self.egraph.canonical_id(id);
179
180        if let Some(&cost) = self.costs.get(&canonical) {
181            return cost;
182        }
183
184        // Get all nodes in this class (resolved from arena)
185        let Some(nodes) = self.egraph.class_nodes_cloned(canonical) else {
186            return PlanCost::ZERO;
187        };
188
189        if nodes.is_empty() {
190            self.costs.insert(canonical, PlanCost::ZERO);
191            return PlanCost::ZERO;
192        }
193
194        // Find the best node in this class
195        let mut best_cost = PlanCost {
196            allocations: u64::MAX,
197            cancel_checkpoints: u64::MAX,
198            obligation_pressure: u64::MAX,
199            critical_path: u64::MAX,
200        };
201        let mut best: Option<ENode> = None;
202
203        for node in nodes {
204            let cost = self.node_cost(&node);
205            if cost.total() < best_cost.total()
206                || (cost.total() == best_cost.total() && best.is_none())
207            {
208                best_cost = cost;
209                best = Some(node);
210            }
211        }
212
213        self.costs.insert(canonical, best_cost);
214        if let Some(node) = best {
215            self.best_node.insert(canonical, node);
216        }
217
218        best_cost
219    }
220
221    /// Computes the cost of a single e-node.
222    fn node_cost(&mut self, node: &ENode) -> PlanCost {
223        match node {
224            ENode::Leaf { label } => {
225                let mut cost = PlanCost::LEAF;
226                if label.starts_with("obl:") {
227                    cost.obligation_pressure = 1;
228                }
229                cost
230            }
231            ENode::Join { children } => {
232                let mut cost = PlanCost::ZERO;
233                for child in children {
234                    let child_cost = self.compute_cost(*child);
235                    cost = cost.add(child_cost);
236                }
237                // Add one allocation for the join combinator
238                cost.allocations = cost.allocations.saturating_add(1);
239                cost
240            }
241            ENode::Race { children } => {
242                let mut cost = PlanCost::ZERO;
243                for child in children {
244                    let child_cost = self.compute_cost(*child);
245                    cost = cost.add(child_cost);
246                }
247                // Race adds a cancel checkpoint
248                cost.cancel_checkpoints = cost.cancel_checkpoints.saturating_add(1);
249                // Add one allocation for the race combinator
250                cost.allocations = cost.allocations.saturating_add(1);
251                cost
252            }
253            ENode::Timeout { child, duration: _ } => {
254                let mut cost = self.compute_cost(*child);
255                // Timeout adds one allocation and increments critical path
256                cost.allocations = cost.allocations.saturating_add(1);
257                cost.critical_path = cost.critical_path.saturating_add(1);
258                cost
259            }
260        }
261    }
262
263    /// Builds a `PlanNode` from the best e-node for a class.
264    fn build_plan_node(
265        &mut self,
266        id: EClassId,
267        dag: &mut PlanDag,
268        id_map: &mut BTreeMap<EClassId, PlanId>,
269    ) -> PlanId {
270        let canonical = self.egraph.canonical_id(id);
271
272        if let Some(&plan_id) = id_map.get(&canonical) {
273            return plan_id;
274        }
275
276        let node = self
277            .best_node
278            .get(&canonical)
279            .cloned()
280            .expect("best_node computed for all reachable classes");
281
282        let plan_id = match &node {
283            ENode::Leaf { label } => dag.leaf(label.as_str()),
284            ENode::Join { children } => {
285                let child_ids: Vec<PlanId> = children
286                    .iter()
287                    .map(|c| self.build_plan_node(*c, dag, id_map))
288                    .collect();
289                dag.join(child_ids)
290            }
291            ENode::Race { children } => {
292                let child_ids: Vec<PlanId> = children
293                    .iter()
294                    .map(|c| self.build_plan_node(*c, dag, id_map))
295                    .collect();
296                dag.race(child_ids)
297            }
298            ENode::Timeout { child, duration } => {
299                let child_id = self.build_plan_node(*child, dag, id_map);
300                dag.timeout(child_id, *duration)
301            }
302        };
303
304        id_map.insert(canonical, plan_id);
305        plan_id
306    }
307}
308
309// ===========================================================================
310// Extraction certificate
311// ===========================================================================
312
313/// Certificate for a plan extraction.
314///
315/// Records the root class, computed cost, and plan hash for verification.
316#[derive(Debug, Clone)]
317pub struct ExtractionCertificate {
318    /// Schema version.
319    pub version: CertificateVersion,
320    /// Root class that was extracted.
321    pub root_class: EClassId,
322    /// Computed cost of the extracted plan.
323    pub cost: PlanCost,
324    /// Stable hash of the extracted plan DAG.
325    pub plan_hash: PlanHash,
326    /// Number of nodes in the extracted plan.
327    pub node_count: usize,
328}
329
330impl ExtractionCertificate {
331    /// Verifies that the certificate matches the given plan DAG.
332    pub fn verify(&self, dag: &PlanDag) -> Result<(), ExtractionVerifyError> {
333        if self.version != CertificateVersion::CURRENT {
334            return Err(ExtractionVerifyError::VersionMismatch {
335                expected: CertificateVersion::CURRENT.number(),
336                found: self.version.number(),
337            });
338        }
339
340        let actual_hash = PlanHash::of(dag);
341        if self.plan_hash != actual_hash {
342            return Err(ExtractionVerifyError::HashMismatch {
343                expected: self.plan_hash.value(),
344                actual: actual_hash.value(),
345            });
346        }
347
348        if self.node_count != dag.nodes.len() {
349            return Err(ExtractionVerifyError::NodeCountMismatch {
350                expected: self.node_count,
351                actual: dag.nodes.len(),
352            });
353        }
354
355        Ok(())
356    }
357}
358
/// Error from extraction verification.
///
/// Returned by [`ExtractionCertificate::verify`]. Implements `Display` and
/// `std::error::Error` so it composes with `?` and error-reporting code.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExtractionVerifyError {
    /// Schema version mismatch.
    VersionMismatch {
        /// Expected version.
        expected: u32,
        /// Found version.
        found: u32,
    },
    /// Plan hash mismatch.
    HashMismatch {
        /// Expected hash.
        expected: u64,
        /// Actual hash.
        actual: u64,
    },
    /// Node count mismatch.
    NodeCountMismatch {
        /// Expected count.
        expected: usize,
        /// Actual count.
        actual: usize,
    },
}

impl std::fmt::Display for ExtractionVerifyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::VersionMismatch { expected, found } => write!(
                f,
                "certificate version mismatch: expected {expected}, found {found}"
            ),
            // Hashes are shown as fixed-width hex so they line up when compared.
            Self::HashMismatch { expected, actual } => write!(
                f,
                "plan hash mismatch: expected {expected:#018x}, actual {actual:#018x}"
            ),
            Self::NodeCountMismatch { expected, actual } => write!(
                f,
                "plan node count mismatch: expected {expected}, actual {actual}"
            ),
        }
    }
}

impl std::error::Error for ExtractionVerifyError {}
384
385// ===========================================================================
386// Tests
387// ===========================================================================
388
#[cfg(test)]
mod tests {
    //! Unit tests for the cost model, the extractor and extraction certificates.

    use super::*;
    use crate::test_utils::init_test_logging;
    use std::time::Duration;

    /// Shared per-test setup: initializes test logging.
    fn init_test() {
        init_test_logging();
    }

    /// A lone leaf extracts to a one-node plan with one allocation and depth one.
    #[test]
    fn extract_single_leaf() {
        init_test();
        let mut eg = EGraph::new();
        let a = eg.add_leaf("a");

        let mut extractor = Extractor::new(&mut eg);
        let (dag, cert) = extractor.extract(a);

        assert_eq!(dag.nodes.len(), 1);
        assert!(cert.verify(&dag).is_ok());
        assert_eq!(cert.cost.allocations, 1);
        assert_eq!(cert.cost.critical_path, 1);
    }

    /// A join sums child allocations (plus its own) and takes the max depth.
    #[test]
    fn extract_join_of_leaves() {
        init_test();
        let mut eg = EGraph::new();
        let a = eg.add_leaf("a");
        let b = eg.add_leaf("b");
        let join = eg.add_join(vec![a, b]);

        let mut extractor = Extractor::new(&mut eg);
        let (dag, cert) = extractor.extract(join);

        assert_eq!(dag.nodes.len(), 3);
        assert!(cert.verify(&dag).is_ok());
        // 2 leaves + 1 join = 3 allocations
        assert_eq!(cert.cost.allocations, 3);
        // Critical path is max of children = 1
        assert_eq!(cert.cost.critical_path, 1);
    }

    /// A race contributes exactly one cancel checkpoint.
    #[test]
    fn extract_race_adds_cancel_checkpoint() {
        init_test();
        let mut eg = EGraph::new();
        let a = eg.add_leaf("a");
        let b = eg.add_leaf("b");
        let race = eg.add_race(vec![a, b]);

        let mut extractor = Extractor::new(&mut eg);
        let (dag, cert) = extractor.extract(race);

        assert_eq!(dag.nodes.len(), 3);
        assert!(cert.verify(&dag).is_ok());
        assert_eq!(cert.cost.cancel_checkpoints, 1);
    }

    /// Only leaves whose label starts with `obl:` add obligation pressure.
    #[test]
    fn extract_obligation_pressure() {
        init_test();
        let mut eg = EGraph::new();
        let obl = eg.add_leaf("obl:permit");
        let plain = eg.add_leaf("compute");
        let join = eg.add_join(vec![obl, plain]);

        let mut extractor = Extractor::new(&mut eg);
        let (dag, cert) = extractor.extract(join);

        assert_eq!(dag.nodes.len(), 3);
        assert!(cert.verify(&dag).is_ok());
        assert_eq!(cert.cost.obligation_pressure, 1);
    }

    /// Each nested timeout lengthens the critical path by one.
    #[test]
    fn extract_nested_critical_path() {
        init_test();
        let mut eg = EGraph::new();
        let a = eg.add_leaf("a");
        let t1 = eg.add_timeout(a, Duration::from_secs(5));
        let t2 = eg.add_timeout(t1, Duration::from_secs(10));

        let mut extractor = Extractor::new(&mut eg);
        let (dag, cert) = extractor.extract(t2);

        assert_eq!(dag.nodes.len(), 3);
        assert!(cert.verify(&dag).is_ok());
        // Leaf (1) + timeout (1) + timeout (1) = 3
        assert_eq!(cert.cost.critical_path, 3);
    }

    /// Two independent extractors over the same e-graph agree on hash, cost and size.
    #[test]
    fn extraction_is_deterministic() {
        init_test();
        let mut eg = EGraph::new();
        let a = eg.add_leaf("a");
        let b = eg.add_leaf("b");
        let c = eg.add_leaf("c");
        let j1 = eg.add_join(vec![a, b]);
        let r = eg.add_race(vec![j1, c]);

        let mut extractor1 = Extractor::new(&mut eg);
        let (dag1, cert1) = extractor1.extract(r);

        // Extract again (new extractor, same egraph)
        let mut extractor2 = Extractor::new(&mut eg);
        let (dag2, cert2) = extractor2.extract(r);

        assert_eq!(cert1.plan_hash, cert2.plan_hash);
        assert_eq!(cert1.cost, cert2.cost);
        assert_eq!(dag1.nodes.len(), dag2.nodes.len());
    }

    /// After merging two equivalent joins, extraction picks the cheaper, flat one.
    #[test]
    fn extract_after_merge_picks_best() {
        init_test();
        let mut eg = EGraph::new();
        let a = eg.add_leaf("a");
        let b = eg.add_leaf("b");
        let c = eg.add_leaf("c");

        // Two different representations of the same thing
        let j1 = eg.add_join(vec![a, b, c]);
        let inner_join = eg.add_join(vec![a, b]);
        let j2 = eg.add_join(vec![inner_join, c]);

        // Merge them into the same class
        eg.merge(j1, j2);

        let mut extractor = Extractor::new(&mut eg);
        let (dag, cert) = extractor.extract(j1);

        // Should pick the flatter representation (lower cost)
        assert!(cert.verify(&dag).is_ok());
        // The flat join is cheaper (fewer allocations)
        assert_eq!(cert.cost.allocations, 4); // 3 leaves + 1 join
    }

    /// The ×1000 critical-path weight outweighs a handful of extra allocations.
    #[test]
    fn cost_total_ordering() {
        init_test();
        let low = PlanCost {
            allocations: 10,
            cancel_checkpoints: 0,
            obligation_pressure: 0,
            critical_path: 1,
        };
        let high = PlanCost {
            allocations: 1,
            cancel_checkpoints: 0,
            obligation_pressure: 0,
            critical_path: 10,
        };

        // Critical path dominates
        assert!(low.total() < high.total());
    }

    /// `Display` renders every component with its short label.
    #[test]
    fn cost_display() {
        init_test();
        let cost = PlanCost {
            allocations: 5,
            cancel_checkpoints: 2,
            obligation_pressure: 1,
            critical_path: 3,
        };
        let display = format!("{cost}");
        assert!(display.contains("alloc=5"));
        assert!(display.contains("cancel=2"));
        assert!(display.contains("obl=1"));
        assert!(display.contains("depth=3"));
    }

    /// A certificate with a foreign schema version is rejected before any other check.
    #[test]
    fn certificate_version_mismatch() {
        init_test();
        let mut eg = EGraph::new();
        let a = eg.add_leaf("a");

        let mut extractor = Extractor::new(&mut eg);
        let (dag, mut cert) = extractor.extract(a);

        cert.version = CertificateVersion::from_number(99);
        let result = cert.verify(&dag);
        assert!(matches!(
            result,
            Err(ExtractionVerifyError::VersionMismatch { .. })
        ));
    }

    /// Mutating the DAG after extraction is caught by the hash check, which
    /// runs before the node-count check.
    #[test]
    fn certificate_hash_mismatch() {
        init_test();
        let mut eg = EGraph::new();
        let a = eg.add_leaf("a");

        let mut extractor = Extractor::new(&mut eg);
        let (mut dag, cert) = extractor.extract(a);

        // Mutate the DAG
        dag.leaf("extra");

        let result = cert.verify(&dag);
        assert!(matches!(
            result,
            Err(ExtractionVerifyError::HashMismatch { .. })
        ));
    }

    // Pure data-type tests (wave 37 – CyanBarn)

    /// `PlanCost` defaults to all zeros and supports `Debug` and `Copy`.
    #[test]
    fn plan_cost_debug_copy_default() {
        let cost = PlanCost::default();
        assert_eq!(cost.allocations, 0);
        assert_eq!(cost.cancel_checkpoints, 0);
        assert_eq!(cost.obligation_pressure, 0);
        assert_eq!(cost.critical_path, 0);

        let dbg = format!("{cost:?}");
        assert!(dbg.contains("PlanCost"));

        // Copy
        let cost2 = cost;
        assert_eq!(cost, cost2);

        // Clone: `PlanCost` is `Copy`, so a plain copy stands in for `.clone()`
        // (calling `.clone()` on a `Copy` type trips `clippy::clone_on_copy`).
        let cost3 = cost;
        assert_eq!(cost, cost3);
    }

    /// Spot-checks the `ZERO`, `LEAF` and `UNKNOWN` constants.
    #[test]
    fn plan_cost_constants() {
        assert_eq!(PlanCost::ZERO.total(), 0);
        assert_eq!(PlanCost::ZERO.allocations, 0);

        assert_eq!(PlanCost::LEAF.allocations, 1);
        assert_eq!(PlanCost::LEAF.critical_path, 1);
        assert_eq!(PlanCost::LEAF.cancel_checkpoints, 0);

        // UNKNOWN is sentinel
        assert_eq!(PlanCost::UNKNOWN.allocations, u64::MAX);
        assert_eq!(PlanCost::UNKNOWN.critical_path, u64::MAX);
    }

    /// `add` takes the max critical path; `sequential` sums it. Both sum the counters.
    #[test]
    fn plan_cost_add_sequential() {
        let a = PlanCost {
            allocations: 2,
            cancel_checkpoints: 1,
            obligation_pressure: 0,
            critical_path: 3,
        };
        let b = PlanCost {
            allocations: 3,
            cancel_checkpoints: 0,
            obligation_pressure: 1,
            critical_path: 5,
        };

        // add: critical_path = max
        let sum = a.add(b);
        assert_eq!(sum.allocations, 5);
        assert_eq!(sum.cancel_checkpoints, 1);
        assert_eq!(sum.obligation_pressure, 1);
        assert_eq!(sum.critical_path, 5); // max(3,5)

        // sequential: critical_path = sum
        let seq = a.sequential(b);
        assert_eq!(seq.allocations, 5);
        assert_eq!(seq.critical_path, 8); // 3+5
    }

    /// `ExtractionCertificate` supports `Debug` and `Clone`.
    #[test]
    fn extraction_certificate_debug_clone() {
        let mut eg = EGraph::new();
        let a = eg.add_leaf("x");
        let mut ext = Extractor::new(&mut eg);
        let (_dag, cert) = ext.extract(a);

        let dbg = format!("{cert:?}");
        assert!(dbg.contains("ExtractionCertificate"));

        let cloned = cert.clone();
        assert_eq!(cloned.node_count, cert.node_count);
        assert_eq!(cloned.cost, cert.cost);
    }

    /// Every `ExtractionVerifyError` variant supports `Debug`, `Clone` and equality.
    #[test]
    fn extraction_verify_error_debug_clone_eq() {
        let e1 = ExtractionVerifyError::VersionMismatch {
            expected: 1,
            found: 2,
        };
        let e2 = ExtractionVerifyError::HashMismatch {
            expected: 10,
            actual: 20,
        };
        let e3 = ExtractionVerifyError::NodeCountMismatch {
            expected: 5,
            actual: 3,
        };

        let dbg1 = format!("{e1:?}");
        assert!(dbg1.contains("VersionMismatch"));
        let dbg2 = format!("{e2:?}");
        assert!(dbg2.contains("HashMismatch"));
        let dbg3 = format!("{e3:?}");
        assert!(dbg3.contains("NodeCountMismatch"));

        // Clone + PartialEq
        let e1c = e1.clone();
        assert_eq!(e1, e1c);
        assert_ne!(e1, e2);
    }
}