1use super::certificate::{CertificateVersion, PlanHash};
8use super::{EClassId, EGraph, ENode, PlanDag, PlanId};
9use std::collections::BTreeMap;
10
/// Multi-dimensional cost of an extracted plan.
///
/// Costs are ranked primarily by the weighted scalar returned by
/// [`PlanCost::total`]; ties are broken component-wise (see the [`Ord`] impl)
/// so that ordering is consistent with the derived [`PartialEq`]/[`Eq`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct PlanCost {
    /// Number of allocations charged to the plan.
    pub allocations: u64,
    /// Number of cancellation checkpoints charged to the plan.
    pub cancel_checkpoints: u64,
    /// Number of obligation-bearing steps charged to the plan.
    pub obligation_pressure: u64,
    /// Length of the longest sequential chain through the plan.
    pub critical_path: u64,
}

impl PlanCost {
    /// The empty cost: every component is zero.
    pub const ZERO: Self = Self {
        allocations: 0,
        cancel_checkpoints: 0,
        obligation_pressure: 0,
        critical_path: 0,
    };

    /// Sentinel for a cost that is not known; every component is `u64::MAX`,
    /// so it is the greatest possible cost.
    pub const UNKNOWN: Self = Self {
        allocations: u64::MAX,
        cancel_checkpoints: u64::MAX,
        obligation_pressure: u64::MAX,
        critical_path: u64::MAX,
    };

    /// Cost of a single leaf: one allocation on a critical path of length one.
    pub const LEAF: Self = Self {
        allocations: 1,
        cancel_checkpoints: 0,
        obligation_pressure: 0,
        critical_path: 1,
    };

    /// Combines two costs that execute in parallel.
    ///
    /// Allocations, cancel checkpoints and obligation pressure are summed
    /// (saturating); the critical path is the longer of the two.
    #[must_use]
    // Deliberately not `std::ops::Add`: the critical path combines with `max`,
    // not `+`, so operator syntax would misrepresent the semantics.
    #[allow(clippy::should_implement_trait)]
    pub fn add(self, other: Self) -> Self {
        self.combine_with_path(other, self.critical_path.max(other.critical_path))
    }

    /// Combines two costs that execute one after the other.
    ///
    /// Every component, including the critical path, is summed (saturating).
    #[must_use]
    pub fn sequential(self, other: Self) -> Self {
        self.combine_with_path(other, self.critical_path.saturating_add(other.critical_path))
    }

    /// Saturating sum of the additive components, paired with the given critical path.
    fn combine_with_path(self, other: Self, critical_path: u64) -> Self {
        Self {
            allocations: self.allocations.saturating_add(other.allocations),
            cancel_checkpoints: self
                .cancel_checkpoints
                .saturating_add(other.cancel_checkpoints),
            obligation_pressure: self
                .obligation_pressure
                .saturating_add(other.obligation_pressure),
            critical_path,
        }
    }

    /// Weighted scalar used to rank costs:
    /// `critical_path * 1000 + cancel_checkpoints * 100 + obligation_pressure * 10 + allocations`,
    /// computed with saturating arithmetic.
    #[must_use]
    pub fn total(&self) -> u64 {
        self.critical_path
            .saturating_mul(1000)
            .saturating_add(self.cancel_checkpoints.saturating_mul(100))
            .saturating_add(self.obligation_pressure.saturating_mul(10))
            .saturating_add(self.allocations)
    }
}

impl PartialOrd for PlanCost {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for PlanCost {
    /// Orders by [`PlanCost::total`], then breaks ties by critical path, cancel
    /// checkpoints, obligation pressure and allocations (in weight order).
    ///
    /// The tie-break is required for consistency with the derived `Eq`: the
    /// weighted total alone maps distinct costs (e.g. 10 allocations vs. one
    /// unit of obligation pressure) to the same value, which would make
    /// `cmp` report `Equal` for values that are `!=`.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.total()
            .cmp(&other.total())
            .then_with(|| self.critical_path.cmp(&other.critical_path))
            .then_with(|| self.cancel_checkpoints.cmp(&other.cancel_checkpoints))
            .then_with(|| self.obligation_pressure.cmp(&other.obligation_pressure))
            .then_with(|| self.allocations.cmp(&other.allocations))
    }
}

impl std::fmt::Display for PlanCost {
    /// Renders the cost as `alloc=… cancel=… obl=… depth=…`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "alloc={} cancel={} obl={} depth={}",
            self.allocations, self.cancel_checkpoints, self.obligation_pressure, self.critical_path
        )
    }
}
119
120#[derive(Debug)]
126pub struct Extractor<'a> {
127 egraph: &'a mut EGraph,
128 costs: BTreeMap<EClassId, PlanCost>,
130 best_node: BTreeMap<EClassId, ENode>,
132}
133
134impl<'a> Extractor<'a> {
135 pub fn new(egraph: &'a mut EGraph) -> Self {
137 Self {
138 egraph,
139 costs: BTreeMap::new(),
140 best_node: BTreeMap::new(),
141 }
142 }
143
144 pub fn extract(&mut self, root: EClassId) -> (PlanDag, ExtractionCertificate) {
149 self.compute_cost(root);
151
152 let mut dag = PlanDag::new();
154 let mut id_map: BTreeMap<EClassId, PlanId> = BTreeMap::new();
155
156 let dag_root = self.build_plan_node(root, &mut dag, &mut id_map);
157 dag.set_root(dag_root);
158
159 let cost = self
160 .costs
161 .get(&self.egraph.canonical_id(root))
162 .copied()
163 .unwrap_or(PlanCost::ZERO);
164
165 let cert = ExtractionCertificate {
166 version: CertificateVersion::CURRENT,
167 root_class: root,
168 cost,
169 plan_hash: PlanHash::of(&dag),
170 node_count: dag.nodes.len(),
171 };
172
173 (dag, cert)
174 }
175
176 fn compute_cost(&mut self, id: EClassId) -> PlanCost {
178 let canonical = self.egraph.canonical_id(id);
179
180 if let Some(&cost) = self.costs.get(&canonical) {
181 return cost;
182 }
183
184 let Some(nodes) = self.egraph.class_nodes_cloned(canonical) else {
186 return PlanCost::ZERO;
187 };
188
189 if nodes.is_empty() {
190 self.costs.insert(canonical, PlanCost::ZERO);
191 return PlanCost::ZERO;
192 }
193
194 let mut best_cost = PlanCost {
196 allocations: u64::MAX,
197 cancel_checkpoints: u64::MAX,
198 obligation_pressure: u64::MAX,
199 critical_path: u64::MAX,
200 };
201 let mut best: Option<ENode> = None;
202
203 for node in nodes {
204 let cost = self.node_cost(&node);
205 if cost.total() < best_cost.total()
206 || (cost.total() == best_cost.total() && best.is_none())
207 {
208 best_cost = cost;
209 best = Some(node);
210 }
211 }
212
213 self.costs.insert(canonical, best_cost);
214 if let Some(node) = best {
215 self.best_node.insert(canonical, node);
216 }
217
218 best_cost
219 }
220
221 fn node_cost(&mut self, node: &ENode) -> PlanCost {
223 match node {
224 ENode::Leaf { label } => {
225 let mut cost = PlanCost::LEAF;
226 if label.starts_with("obl:") {
227 cost.obligation_pressure = 1;
228 }
229 cost
230 }
231 ENode::Join { children } => {
232 let mut cost = PlanCost::ZERO;
233 for child in children {
234 let child_cost = self.compute_cost(*child);
235 cost = cost.add(child_cost);
236 }
237 cost.allocations = cost.allocations.saturating_add(1);
239 cost
240 }
241 ENode::Race { children } => {
242 let mut cost = PlanCost::ZERO;
243 for child in children {
244 let child_cost = self.compute_cost(*child);
245 cost = cost.add(child_cost);
246 }
247 cost.cancel_checkpoints = cost.cancel_checkpoints.saturating_add(1);
249 cost.allocations = cost.allocations.saturating_add(1);
251 cost
252 }
253 ENode::Timeout { child, duration: _ } => {
254 let mut cost = self.compute_cost(*child);
255 cost.allocations = cost.allocations.saturating_add(1);
257 cost.critical_path = cost.critical_path.saturating_add(1);
258 cost
259 }
260 }
261 }
262
263 fn build_plan_node(
265 &mut self,
266 id: EClassId,
267 dag: &mut PlanDag,
268 id_map: &mut BTreeMap<EClassId, PlanId>,
269 ) -> PlanId {
270 let canonical = self.egraph.canonical_id(id);
271
272 if let Some(&plan_id) = id_map.get(&canonical) {
273 return plan_id;
274 }
275
276 let node = self
277 .best_node
278 .get(&canonical)
279 .cloned()
280 .expect("best_node computed for all reachable classes");
281
282 let plan_id = match &node {
283 ENode::Leaf { label } => dag.leaf(label.as_str()),
284 ENode::Join { children } => {
285 let child_ids: Vec<PlanId> = children
286 .iter()
287 .map(|c| self.build_plan_node(*c, dag, id_map))
288 .collect();
289 dag.join(child_ids)
290 }
291 ENode::Race { children } => {
292 let child_ids: Vec<PlanId> = children
293 .iter()
294 .map(|c| self.build_plan_node(*c, dag, id_map))
295 .collect();
296 dag.race(child_ids)
297 }
298 ENode::Timeout { child, duration } => {
299 let child_id = self.build_plan_node(*child, dag, id_map);
300 dag.timeout(child_id, *duration)
301 }
302 };
303
304 id_map.insert(canonical, plan_id);
305 plan_id
306 }
307}
308
309#[derive(Debug, Clone)]
317pub struct ExtractionCertificate {
318 pub version: CertificateVersion,
320 pub root_class: EClassId,
322 pub cost: PlanCost,
324 pub plan_hash: PlanHash,
326 pub node_count: usize,
328}
329
330impl ExtractionCertificate {
331 pub fn verify(&self, dag: &PlanDag) -> Result<(), ExtractionVerifyError> {
333 if self.version != CertificateVersion::CURRENT {
334 return Err(ExtractionVerifyError::VersionMismatch {
335 expected: CertificateVersion::CURRENT.number(),
336 found: self.version.number(),
337 });
338 }
339
340 let actual_hash = PlanHash::of(dag);
341 if self.plan_hash != actual_hash {
342 return Err(ExtractionVerifyError::HashMismatch {
343 expected: self.plan_hash.value(),
344 actual: actual_hash.value(),
345 });
346 }
347
348 if self.node_count != dag.nodes.len() {
349 return Err(ExtractionVerifyError::NodeCountMismatch {
350 expected: self.node_count,
351 actual: dag.nodes.len(),
352 });
353 }
354
355 Ok(())
356 }
357}
358
/// Reasons an extraction certificate fails to verify against a plan.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExtractionVerifyError {
    /// The certificate was written with a different format version.
    VersionMismatch {
        /// Version number this build expects.
        expected: u32,
        /// Version number recorded in the certificate.
        found: u32,
    },
    /// The plan's hash differs from the one recorded in the certificate.
    HashMismatch {
        /// Hash recorded in the certificate.
        expected: u64,
        /// Hash recomputed from the plan.
        actual: u64,
    },
    /// The plan's node count differs from the one recorded in the certificate.
    NodeCountMismatch {
        /// Node count recorded in the certificate.
        expected: usize,
        /// Node count of the plan.
        actual: usize,
    },
}

impl std::fmt::Display for ExtractionVerifyError {
    /// Human-readable description of the mismatch; hashes are shown in hex.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::VersionMismatch { expected, found } => write!(
                f,
                "certificate version mismatch: expected {expected}, found {found}"
            ),
            Self::HashMismatch { expected, actual } => write!(
                f,
                "plan hash mismatch: expected {expected:#018x}, actual {actual:#018x}"
            ),
            Self::NodeCountMismatch { expected, actual } => write!(
                f,
                "plan node count mismatch: expected {expected}, actual {actual}"
            ),
        }
    }
}

// Lets callers propagate verification failures with `?` into
// `Box<dyn std::error::Error>` and similar error types.
impl std::error::Error for ExtractionVerifyError {}
384
/// Unit tests for cost arithmetic, extraction, and certificate verification.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::init_test_logging;
    use std::time::Duration;

    /// Common per-test setup: initializes test logging.
    fn init_test() {
        init_test_logging();
    }

    /// A single leaf extracts to a one-node plan costing one allocation at depth one.
    #[test]
    fn extract_single_leaf() {
        init_test();
        let mut eg = EGraph::new();
        let a = eg.add_leaf("a");

        let mut extractor = Extractor::new(&mut eg);
        let (dag, cert) = extractor.extract(a);

        assert_eq!(dag.nodes.len(), 1);
        assert!(cert.verify(&dag).is_ok());
        assert_eq!(cert.cost.allocations, 1);
        assert_eq!(cert.cost.critical_path, 1);
    }

    /// A join of two leaves yields three plan nodes.
    #[test]
    fn extract_join_of_leaves() {
        init_test();
        let mut eg = EGraph::new();
        let a = eg.add_leaf("a");
        let b = eg.add_leaf("b");
        let join = eg.add_join(vec![a, b]);

        let mut extractor = Extractor::new(&mut eg);
        let (dag, cert) = extractor.extract(join);

        assert_eq!(dag.nodes.len(), 3);
        assert!(cert.verify(&dag).is_ok());
        // One allocation per leaf plus one for the join itself.
        assert_eq!(cert.cost.allocations, 3);
        // Join children combine in parallel, so depth is the deepest child (1).
        assert_eq!(cert.cost.critical_path, 1);
    }

    /// A race charges exactly one cancel checkpoint.
    #[test]
    fn extract_race_adds_cancel_checkpoint() {
        init_test();
        let mut eg = EGraph::new();
        let a = eg.add_leaf("a");
        let b = eg.add_leaf("b");
        let race = eg.add_race(vec![a, b]);

        let mut extractor = Extractor::new(&mut eg);
        let (dag, cert) = extractor.extract(race);

        assert_eq!(dag.nodes.len(), 3);
        assert!(cert.verify(&dag).is_ok());
        assert_eq!(cert.cost.cancel_checkpoints, 1);
    }

    /// Only leaves labelled with the `obl:` prefix add obligation pressure.
    #[test]
    fn extract_obligation_pressure() {
        init_test();
        let mut eg = EGraph::new();
        let obl = eg.add_leaf("obl:permit");
        let plain = eg.add_leaf("compute");
        let join = eg.add_join(vec![obl, plain]);

        let mut extractor = Extractor::new(&mut eg);
        let (dag, cert) = extractor.extract(join);

        assert_eq!(dag.nodes.len(), 3);
        assert!(cert.verify(&dag).is_ok());
        assert_eq!(cert.cost.obligation_pressure, 1);
    }

    /// Each timeout wrapper lengthens the critical path by one step.
    #[test]
    fn extract_nested_critical_path() {
        init_test();
        let mut eg = EGraph::new();
        let a = eg.add_leaf("a");
        let t1 = eg.add_timeout(a, Duration::from_secs(5));
        let t2 = eg.add_timeout(t1, Duration::from_secs(10));

        let mut extractor = Extractor::new(&mut eg);
        let (dag, cert) = extractor.extract(t2);

        assert_eq!(dag.nodes.len(), 3);
        assert!(cert.verify(&dag).is_ok());
        // Leaf (1) + two timeouts (+1 each).
        assert_eq!(cert.cost.critical_path, 3);
    }

    /// Two independent extractors over the same e-graph produce the same plan.
    #[test]
    fn extraction_is_deterministic() {
        init_test();
        let mut eg = EGraph::new();
        let a = eg.add_leaf("a");
        let b = eg.add_leaf("b");
        let c = eg.add_leaf("c");
        let j1 = eg.add_join(vec![a, b]);
        let r = eg.add_race(vec![j1, c]);

        let mut extractor1 = Extractor::new(&mut eg);
        let (dag1, cert1) = extractor1.extract(r);

        // A fresh extractor starts with empty memo tables.
        let mut extractor2 = Extractor::new(&mut eg);
        let (dag2, cert2) = extractor2.extract(r);

        assert_eq!(cert1.plan_hash, cert2.plan_hash);
        assert_eq!(cert1.cost, cert2.cost);
        assert_eq!(dag1.nodes.len(), dag2.nodes.len());
    }

    /// After merging two equivalent joins, extraction picks the cheaper one.
    #[test]
    fn extract_after_merge_picks_best() {
        init_test();
        let mut eg = EGraph::new();
        let a = eg.add_leaf("a");
        let b = eg.add_leaf("b");
        let c = eg.add_leaf("c");

        // Flat join(a, b, c): 3 leaves + 1 join = 4 allocations.
        let j1 = eg.add_join(vec![a, b, c]);
        // Nested join(join(a, b), c): 3 leaves + 2 joins = 5 allocations.
        let inner_join = eg.add_join(vec![a, b]);
        let j2 = eg.add_join(vec![inner_join, c]);

        // Declare the two forms equivalent so both live in one e-class.
        eg.merge(j1, j2);

        let mut extractor = Extractor::new(&mut eg);
        let (dag, cert) = extractor.extract(j1);

        assert!(cert.verify(&dag).is_ok());
        // The flat form is chosen.
        assert_eq!(cert.cost.allocations, 4);
    }

    /// Critical path dominates the weighted total over raw allocation count.
    #[test]
    fn cost_total_ordering() {
        init_test();
        let low = PlanCost {
            allocations: 10,
            cancel_checkpoints: 0,
            obligation_pressure: 0,
            critical_path: 1,
        };
        let high = PlanCost {
            allocations: 1,
            cancel_checkpoints: 0,
            obligation_pressure: 0,
            critical_path: 10,
        };

        assert!(low.total() < high.total());
    }

    /// `Display` output includes every labelled component.
    #[test]
    fn cost_display() {
        init_test();
        let cost = PlanCost {
            allocations: 5,
            cancel_checkpoints: 2,
            obligation_pressure: 1,
            critical_path: 3,
        };
        let display = format!("{cost}");
        assert!(display.contains("alloc=5"));
        assert!(display.contains("cancel=2"));
        assert!(display.contains("obl=1"));
        assert!(display.contains("depth=3"));
    }

    /// A certificate with a foreign version number is rejected.
    #[test]
    fn certificate_version_mismatch() {
        init_test();
        let mut eg = EGraph::new();
        let a = eg.add_leaf("a");

        let mut extractor = Extractor::new(&mut eg);
        let (dag, mut cert) = extractor.extract(a);

        cert.version = CertificateVersion::from_number(99);
        let result = cert.verify(&dag);
        assert!(matches!(
            result,
            Err(ExtractionVerifyError::VersionMismatch { .. })
        ));
    }

    /// Mutating the plan after extraction is detected by the hash check,
    /// which runs before the node-count check.
    #[test]
    fn certificate_hash_mismatch() {
        init_test();
        let mut eg = EGraph::new();
        let a = eg.add_leaf("a");

        let mut extractor = Extractor::new(&mut eg);
        let (mut dag, cert) = extractor.extract(a);

        // Tamper with the plan by appending an extra node.
        dag.leaf("extra");

        let result = cert.verify(&dag);
        assert!(matches!(
            result,
            Err(ExtractionVerifyError::HashMismatch { .. })
        ));
    }

    /// `PlanCost` derives `Default` (all zeros), `Debug`, and `Copy`/`PartialEq`.
    #[test]
    fn plan_cost_debug_copy_default() {
        let cost = PlanCost::default();
        assert_eq!(cost.allocations, 0);
        assert_eq!(cost.cancel_checkpoints, 0);
        assert_eq!(cost.obligation_pressure, 0);
        assert_eq!(cost.critical_path, 0);

        let dbg = format!("{cost:?}");
        assert!(dbg.contains("PlanCost"));

        // Copy: `cost` remains usable after being assigned.
        let cost2 = cost;
        assert_eq!(cost, cost2);

        let cost3 = cost;
        assert_eq!(cost, cost3);
    }

    /// Sanity checks on the `ZERO`, `LEAF`, and `UNKNOWN` constants.
    #[test]
    fn plan_cost_constants() {
        assert_eq!(PlanCost::ZERO.total(), 0);
        assert_eq!(PlanCost::ZERO.allocations, 0);

        assert_eq!(PlanCost::LEAF.allocations, 1);
        assert_eq!(PlanCost::LEAF.critical_path, 1);
        assert_eq!(PlanCost::LEAF.cancel_checkpoints, 0);

        assert_eq!(PlanCost::UNKNOWN.allocations, u64::MAX);
        assert_eq!(PlanCost::UNKNOWN.critical_path, u64::MAX);
    }

    /// `add` takes the max critical path; `sequential` sums it.
    #[test]
    fn plan_cost_add_sequential() {
        let a = PlanCost {
            allocations: 2,
            cancel_checkpoints: 1,
            obligation_pressure: 0,
            critical_path: 3,
        };
        let b = PlanCost {
            allocations: 3,
            cancel_checkpoints: 0,
            obligation_pressure: 1,
            critical_path: 5,
        };

        // Parallel: additive components sum, depth is max(3, 5).
        let sum = a.add(b);
        assert_eq!(sum.allocations, 5);
        assert_eq!(sum.cancel_checkpoints, 1);
        assert_eq!(sum.obligation_pressure, 1);
        assert_eq!(sum.critical_path, 5);

        // Sequential: depth is 3 + 5.
        let seq = a.sequential(b);
        assert_eq!(seq.allocations, 5);
        assert_eq!(seq.critical_path, 8);
    }

    /// `ExtractionCertificate` derives `Debug` and `Clone`.
    #[test]
    fn extraction_certificate_debug_clone() {
        let mut eg = EGraph::new();
        let a = eg.add_leaf("x");
        let mut ext = Extractor::new(&mut eg);
        let (_dag, cert) = ext.extract(a);

        let dbg = format!("{cert:?}");
        assert!(dbg.contains("ExtractionCertificate"));

        let cloned = cert.clone();
        assert_eq!(cloned.node_count, cert.node_count);
        assert_eq!(cloned.cost, cert.cost);
    }

    /// `ExtractionVerifyError` derives `Debug`, `Clone`, and `PartialEq`.
    #[test]
    fn extraction_verify_error_debug_clone_eq() {
        let e1 = ExtractionVerifyError::VersionMismatch {
            expected: 1,
            found: 2,
        };
        let e2 = ExtractionVerifyError::HashMismatch {
            expected: 10,
            actual: 20,
        };
        let e3 = ExtractionVerifyError::NodeCountMismatch {
            expected: 5,
            actual: 3,
        };

        let dbg1 = format!("{e1:?}");
        assert!(dbg1.contains("VersionMismatch"));
        let dbg2 = format!("{e2:?}");
        assert!(dbg2.contains("HashMismatch"));
        let dbg3 = format!("{e3:?}");
        assert!(dbg3.contains("NodeCountMismatch"));

        let e1c = e1.clone();
        assert_eq!(e1, e1c);
        assert_ne!(e1, e2);
    }
}