use std::fmt::Write;
use std::time::Duration;
use super::PlanDag;
use super::rewrite::RewriteRule;
#[derive(Debug)]
pub struct PlanFixture {
pub name: &'static str,
pub intent: &'static str,
pub dag: PlanDag,
pub expected_rules: Vec<RewriteRule>,
pub expected_step_count: usize,
}
#[must_use]
pub fn all_fixtures() -> Vec<PlanFixture> {
vec![
simple_join_race_dedup(),
three_way_race_of_joins(),
nested_timeout_join_race(),
no_shared_child(),
single_branch_race(),
deep_chain_no_rewrite(),
shared_non_leaf_conservative(),
shared_non_leaf_associative(),
diamond_join_race(),
timeout_wrapping_dedup(),
independent_subtrees(),
race_of_leaves(),
race_cancel_with_timeout(),
nested_race_cancel_cascade(),
timeout_race_dedup_cancel(),
race_obligation_cancel(),
]
}
fn simple_join_race_dedup() -> PlanFixture {
let mut dag = PlanDag::new();
let shared = dag.leaf("shared");
let left = dag.leaf("left");
let right = dag.leaf("right");
let join_a = dag.join(vec![shared, left]);
let join_b = dag.join(vec![shared, right]);
let race = dag.race(vec![join_a, join_b]);
dag.set_root(race);
PlanFixture {
name: "simple_join_race_dedup",
intent: "Basic DedupRaceJoin: Race[Join[s,a], Join[s,b]] -> Join[s, Race[a,b]]",
dag,
expected_rules: vec![RewriteRule::DedupRaceJoin],
expected_step_count: 1,
}
}
fn three_way_race_of_joins() -> PlanFixture {
let mut dag = PlanDag::new();
let shared = dag.leaf("shared");
let a = dag.leaf("a");
let b = dag.leaf("b");
let c = dag.leaf("c");
let j1 = dag.join(vec![shared, a]);
let j2 = dag.join(vec![shared, b]);
let j3 = dag.join(vec![shared, c]);
let race = dag.race(vec![j1, j2, j3]);
dag.set_root(race);
PlanFixture {
name: "three_way_race_of_joins",
intent: "3-way race: conservative rejects non-binary race",
dag,
expected_rules: vec![],
expected_step_count: 0,
}
}
fn nested_timeout_join_race() -> PlanFixture {
let mut dag = PlanDag::new();
let shared = dag.leaf("shared");
let left = dag.leaf("left");
let right = dag.leaf("right");
let join_a = dag.join(vec![shared, left]);
let join_b = dag.join(vec![shared, right]);
let race = dag.race(vec![join_a, join_b]);
let timed = dag.timeout(race, Duration::from_secs(5));
dag.set_root(timed);
PlanFixture {
name: "nested_timeout_join_race",
intent: "Timeout[Race[Join[s,a], Join[s,b]]] -> Timeout[Join[s, Race[a,b]]]",
dag,
expected_rules: vec![RewriteRule::DedupRaceJoin],
expected_step_count: 1,
}
}
fn no_shared_child() -> PlanFixture {
let mut dag = PlanDag::new();
let a = dag.leaf("a");
let b = dag.leaf("b");
let c = dag.leaf("c");
let d = dag.leaf("d");
let j1 = dag.join(vec![a, b]);
let j2 = dag.join(vec![c, d]);
let race = dag.race(vec![j1, j2]);
dag.set_root(race);
PlanFixture {
name: "no_shared_child",
intent: "No shared child across joins; no rewrite fires",
dag,
expected_rules: vec![],
expected_step_count: 0,
}
}
fn single_branch_race() -> PlanFixture {
let mut dag = PlanDag::new();
let a = dag.leaf("a");
let b = dag.leaf("b");
let join = dag.join(vec![a, b]);
let race = dag.race(vec![join]);
dag.set_root(race);
PlanFixture {
name: "single_branch_race",
intent: "Single-branch race: DedupRaceJoin requires >= 2 children",
dag,
expected_rules: vec![],
expected_step_count: 0,
}
}
fn deep_chain_no_rewrite() -> PlanFixture {
let mut dag = PlanDag::new();
let a = dag.leaf("a");
let b = dag.leaf("b");
let c = dag.leaf("c");
let d = dag.leaf("d");
let j1 = dag.join(vec![a, b]);
let j2 = dag.join(vec![j1, c]);
let j3 = dag.join(vec![j2, d]);
dag.set_root(j3);
PlanFixture {
name: "deep_chain_no_rewrite",
intent: "Linear join chain with no race; no rewrite applicable",
dag,
expected_rules: vec![],
expected_step_count: 0,
}
}
fn shared_non_leaf_conservative() -> PlanFixture {
let mut dag = PlanDag::new();
let x = dag.leaf("x");
let y = dag.leaf("y");
let shared_join = dag.join(vec![x, y]);
let a = dag.leaf("a");
let b = dag.leaf("b");
let j1 = dag.join(vec![shared_join, a]);
let j2 = dag.join(vec![shared_join, b]);
let race = dag.race(vec![j1, j2]);
dag.set_root(race);
PlanFixture {
name: "shared_non_leaf_conservative",
intent: "Shared child is a Join (non-leaf); conservative policy rejects",
dag,
expected_rules: vec![],
expected_step_count: 0,
}
}
fn shared_non_leaf_associative() -> PlanFixture {
let mut dag = PlanDag::new();
let x = dag.leaf("x");
let y = dag.leaf("y");
let shared_join = dag.join(vec![x, y]);
let a = dag.leaf("a");
let b = dag.leaf("b");
let j1 = dag.join(vec![shared_join, a]);
let j2 = dag.join(vec![shared_join, b]);
let race = dag.race(vec![j1, j2]);
dag.set_root(race);
PlanFixture {
name: "shared_non_leaf_associative",
intent: "Shared non-leaf under AssumeAssociativeComm: rewrite fires",
dag,
expected_rules: vec![RewriteRule::DedupRaceJoin],
expected_step_count: 1,
}
}
fn diamond_join_race() -> PlanFixture {
let mut dag = PlanDag::new();
let a = dag.leaf("a");
let b = dag.leaf("b");
let c = dag.leaf("c");
let race = dag.race(vec![b, c]);
let join = dag.join(vec![a, race]);
dag.set_root(join);
PlanFixture {
name: "diamond_join_race",
intent: "Join[a, Race[b,c]]: already in deduped form; no rewrite",
dag,
expected_rules: vec![],
expected_step_count: 0,
}
}
fn timeout_wrapping_dedup() -> PlanFixture {
let mut dag = PlanDag::new();
let shared = dag.leaf("shared");
let a = dag.leaf("a");
let b = dag.leaf("b");
let j1 = dag.join(vec![shared, a]);
let j2 = dag.join(vec![shared, b]);
let race = dag.race(vec![j1, j2]);
let timed = dag.timeout(race, Duration::from_millis(500));
let outer = dag.leaf("outer");
let top = dag.join(vec![outer, timed]);
dag.set_root(top);
PlanFixture {
name: "timeout_wrapping_dedup",
intent: "Join[outer, Timeout[Race[Join[s,a], Join[s,b]]]]: inner race rewrites",
dag,
expected_rules: vec![RewriteRule::DedupRaceJoin],
expected_step_count: 1,
}
}
fn independent_subtrees() -> PlanFixture {
let mut dag = PlanDag::new();
let a = dag.leaf("a");
let b = dag.leaf("b");
let c = dag.leaf("c");
let d = dag.leaf("d");
let left = dag.join(vec![a, b]);
let right = dag.join(vec![c, d]);
let top = dag.join(vec![left, right]);
dag.set_root(top);
PlanFixture {
name: "independent_subtrees",
intent: "Two independent join subtrees; no race pattern",
dag,
expected_rules: vec![],
expected_step_count: 0,
}
}
fn race_of_leaves() -> PlanFixture {
let mut dag = PlanDag::new();
let a = dag.leaf("a");
let b = dag.leaf("b");
let c = dag.leaf("c");
let race = dag.race(vec![a, b, c]);
dag.set_root(race);
PlanFixture {
name: "race_of_leaves",
intent: "Race[a,b,c]: children aren't joins, no DedupRaceJoin",
dag,
expected_rules: vec![],
expected_step_count: 0,
}
}
fn race_cancel_with_timeout() -> PlanFixture {
let mut dag = PlanDag::new();
let fast = dag.leaf("fast");
let slow = dag.leaf("slow");
let timed_slow = dag.timeout(slow, Duration::from_secs(3));
let race = dag.race(vec![fast, timed_slow]);
dag.set_root(race);
PlanFixture {
name: "race_cancel_with_timeout",
intent: "Race[fast, Timeout[slow]]: loser cancelled, timeout interacts with cancel",
dag,
expected_rules: vec![],
expected_step_count: 0,
}
}
fn nested_race_cancel_cascade() -> PlanFixture {
let mut dag = PlanDag::new();
let a = dag.leaf("a");
let b = dag.leaf("b");
let c = dag.leaf("c");
let d = dag.leaf("d");
let inner_race1 = dag.race(vec![a, b]);
let inner_race2 = dag.race(vec![c, d]);
let outer_race = dag.race(vec![inner_race1, inner_race2]);
dag.set_root(outer_race);
PlanFixture {
name: "nested_race_cancel_cascade",
intent: "Race[Race[a,b], Race[c,d]]: cancel cascades through nested races",
dag,
expected_rules: vec![],
expected_step_count: 0,
}
}
fn timeout_race_dedup_cancel() -> PlanFixture {
let mut dag = PlanDag::new();
let shared = dag.leaf("shared");
let a = dag.leaf("a");
let b = dag.leaf("b");
let timed_a = dag.timeout(a, Duration::from_secs(2));
let timed_b = dag.timeout(b, Duration::from_secs(4));
let join_a = dag.join(vec![shared, timed_a]);
let join_b = dag.join(vec![shared, timed_b]);
let race = dag.race(vec![join_a, join_b]);
dag.set_root(race);
PlanFixture {
name: "timeout_race_dedup_cancel",
intent: "Race[Join[s,Timeout[a]], Join[s,Timeout[b]]]: dedup + cancel with timed leaves",
dag,
expected_rules: vec![RewriteRule::DedupRaceJoin],
expected_step_count: 1,
}
}
fn race_obligation_cancel() -> PlanFixture {
let mut dag = PlanDag::new();
let obl_permit = dag.leaf("obl:permit");
let obl_lock = dag.leaf("obl:lock");
let compute = dag.leaf("compute");
let join_permit = dag.join(vec![obl_permit, compute]);
let race = dag.race(vec![join_permit, obl_lock]);
dag.set_root(race);
PlanFixture {
name: "race_obligation_cancel",
intent: "Race[Join[obl:permit, compute], obl:lock]: cancel must not leak obligations",
dag,
expected_rules: vec![],
expected_step_count: 0,
}
}
use std::collections::{BTreeSet, HashMap};
use super::certificate::{PlanHash, RewriteCertificate, verify, verify_steps};
use super::extractor::PlanCost;
use super::rewrite::RewritePolicy;
use super::{PlanId, PlanNode};
#[derive(Debug, Clone)]
pub struct LabEquivalenceReport {
pub fixture_name: &'static str,
pub original_hash: PlanHash,
pub optimized_hash: PlanHash,
pub certificate: RewriteCertificate,
pub original_outcomes: BTreeSet<Vec<String>>,
pub optimized_outcomes: BTreeSet<Vec<String>>,
pub outcomes_equivalent: bool,
pub certificate_verified: bool,
pub steps_verified: bool,
}
impl LabEquivalenceReport {
#[must_use]
pub fn all_ok(&self) -> bool {
self.outcomes_equivalent && self.certificate_verified && self.steps_verified
}
#[must_use]
pub fn diff_summary(&self) -> Option<String> {
if self.all_ok() {
return None;
}
let mut out = format!("Fixture: {}\n", self.fixture_name);
if !self.outcomes_equivalent {
let _ = write!(
&mut out,
" OUTCOME MISMATCH:\n original: {:?}\n optimized: {:?}\n",
self.original_outcomes, self.optimized_outcomes
);
}
if !self.certificate_verified {
out.push_str(" CERTIFICATE HASH MISMATCH\n");
}
if !self.steps_verified {
out.push_str(" STEP VERIFICATION FAILED\n");
}
Some(out)
}
}
#[must_use]
pub fn outcome_sets(dag: &PlanDag, id: PlanId) -> BTreeSet<Vec<String>> {
let mut memo = HashMap::new();
outcome_sets_inner(dag, id, &mut memo)
}
fn outcome_sets_inner(
dag: &PlanDag,
id: PlanId,
memo: &mut HashMap<PlanId, BTreeSet<Vec<String>>>,
) -> BTreeSet<Vec<String>> {
if let Some(cached) = memo.get(&id) {
return cached.clone();
}
let Some(node) = dag.node(id) else {
return BTreeSet::new();
};
let result = match node {
PlanNode::Leaf { label } => {
let mut set = BTreeSet::new();
set.insert(vec![label.clone()]);
set
}
PlanNode::Join { children } => {
let mut acc = BTreeSet::new();
acc.insert(Vec::new());
for child in children {
let child_sets = outcome_sets_inner(dag, *child, memo);
let mut next = BTreeSet::new();
for base in &acc {
for child_set in &child_sets {
let mut merged = base.clone();
merged.extend(child_set.iter().cloned());
merged.sort();
merged.dedup();
next.insert(merged);
}
}
acc = next;
}
acc
}
PlanNode::Race { children } => {
let mut acc = BTreeSet::new();
for child in children {
let child_sets = outcome_sets_inner(dag, *child, memo);
acc.extend(child_sets);
}
acc
}
PlanNode::Timeout { child, .. } => outcome_sets_inner(dag, *child, memo),
};
memo.insert(id, result.clone());
result
}
#[must_use]
pub fn run_equivalence_harness(
mut fixture: PlanFixture,
policy: RewritePolicy,
rules: &[RewriteRule],
) -> LabEquivalenceReport {
let original_hash = PlanHash::of(&fixture.dag);
let original_outcomes = fixture
.dag
.root()
.map(|root| outcome_sets(&fixture.dag, root))
.unwrap_or_default();
let (_, certificate) = fixture.dag.apply_rewrites_certified(policy, rules);
let optimized_hash = PlanHash::of(&fixture.dag);
let optimized_outcomes = fixture
.dag
.root()
.map(|root| outcome_sets(&fixture.dag, root))
.unwrap_or_default();
let outcomes_equivalent = original_outcomes == optimized_outcomes;
let certificate_verified = verify(&certificate, &fixture.dag).is_ok();
let steps_verified = verify_steps(&certificate, &fixture.dag).is_ok();
LabEquivalenceReport {
fixture_name: fixture.name,
original_hash,
optimized_hash,
certificate,
original_outcomes,
optimized_outcomes,
outcomes_equivalent,
certificate_verified,
steps_verified,
}
}
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use crate::lab::runtime::LabRuntime;
use crate::runtime::TaskHandle;
use crate::types::{Budget, CancelReason, TaskId};
struct LabYieldOnce {
yielded: bool,
}
impl Future for LabYieldOnce {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.yielded {
Poll::Ready(())
} else {
self.yielded = true;
cx.waker().wake_by_ref();
Poll::Pending
}
}
}
fn lab_yield_once() -> LabYieldOnce {
LabYieldOnce { yielded: false }
}
enum LabJoinState {
Empty,
InFlight,
Ready(BTreeSet<String>),
}
struct SharedLabInner {
handle: parking_lot::Mutex<Option<TaskHandle<BTreeSet<String>>>>,
state: parking_lot::Mutex<LabJoinState>,
}
#[derive(Clone)]
struct SharedLabHandle {
inner: std::sync::Arc<SharedLabInner>,
}
impl SharedLabHandle {
fn new(handle: TaskHandle<BTreeSet<String>>) -> Self {
Self {
inner: std::sync::Arc::new(SharedLabInner {
handle: parking_lot::Mutex::new(Some(handle)),
state: parking_lot::Mutex::new(LabJoinState::Empty),
}),
}
}
async fn join(&self) -> BTreeSet<String> {
let cx = crate::cx::Cx::current().expect("cx set");
let i_am_joiner;
{
let mut state = self.inner.state.lock();
match &*state {
LabJoinState::Ready(result) => return result.clone(),
LabJoinState::InFlight => {
i_am_joiner = false;
}
LabJoinState::Empty => {
*state = LabJoinState::InFlight;
drop(state);
i_am_joiner = true;
}
}
}
if i_am_joiner {
struct Guard<'a> {
inner: &'a SharedLabInner,
handle: Option<TaskHandle<BTreeSet<String>>>,
}
impl Drop for Guard<'_> {
fn drop(&mut self) {
if let Some(h) = self.handle.take() {
*self.inner.handle.lock() = Some(h);
let mut state = self.inner.state.lock();
if matches!(*state, LabJoinState::InFlight) {
*state = LabJoinState::Empty;
}
}
}
}
let mut guard = Guard {
inner: &self.inner,
handle: Some(
self.inner
.handle
.lock()
.take()
.expect("join handle available"),
),
};
let result = guard
.handle
.as_mut()
.expect("handle must exist")
.join(&cx)
.await
.unwrap_or_default();
*self.inner.state.lock() = LabJoinState::Ready(result.clone());
result
} else {
loop {
{
let state = self.inner.state.lock();
if let LabJoinState::Ready(result) = &*state {
return result.clone();
}
}
if cx.checkpoint().is_err() {
return BTreeSet::new();
}
let now = cx
.timer_driver()
.map_or_else(crate::time::wall_now, |driver| driver.now());
crate::time::sleep(now, std::time::Duration::from_millis(10)).await;
}
}
}
#[allow(clippy::significant_drop_tightening)]
fn try_join_probe(&self) -> Option<BTreeSet<String>> {
let mut state = self.inner.state.lock();
match &*state {
LabJoinState::Ready(result) => Some(result.clone()),
LabJoinState::InFlight => None,
LabJoinState::Empty => {
*state = LabJoinState::InFlight;
drop(state);
let mut handle = self
.inner
.handle
.lock()
.take()
.expect("join handle available");
let join_result = handle.try_join();
*self.inner.handle.lock() = Some(handle);
let mut state = self.inner.state.lock();
match join_result {
Ok(Some(result)) => {
*state = LabJoinState::Ready(result.clone());
Some(result)
}
Ok(None) => {
*state = LabJoinState::Empty;
None
}
Err(_) => {
*state = LabJoinState::Ready(BTreeSet::new());
drop(state);
Some(BTreeSet::new())
}
}
}
}
}
fn abort_with_reason(&self, reason: CancelReason) {
if let Some(handle) = self.inner.handle.lock().as_ref() {
handle.abort_with_reason(reason);
}
}
}
#[must_use]
pub fn dag_has_fan_in(dag: &PlanDag) -> bool {
use super::PlanNode;
let Some(root) = dag.root() else {
return false;
};
let mut ref_counts = vec![0u32; dag.nodes.len()];
let mut seen = BTreeSet::new();
let mut stack = vec![root];
while let Some(id) = stack.pop() {
if !seen.insert(id) {
continue;
}
let Some(node) = dag.node(id) else {
continue;
};
match node {
PlanNode::Leaf { .. } => {}
PlanNode::Join { children } | PlanNode::Race { children } => {
for c in children {
if let Some(count) = ref_counts.get_mut(c.index()) {
*count += 1;
if *count > 1 {
return true;
}
}
stack.push(*c);
}
}
PlanNode::Timeout { child, .. } => {
if let Some(count) = ref_counts.get_mut(child.index()) {
*count += 1;
if *count > 1 {
return true;
}
}
stack.push(*child);
}
}
}
false
}
#[must_use]
pub fn execute_plan_in_lab(seed: u64, dag: &PlanDag) -> BTreeSet<String> {
crate::lab::runtime::test(seed, |runtime| execute_plan_in_lab_core(runtime, seed, dag))
}
fn execute_plan_in_lab_traced(seed: u64, dag: &PlanDag) -> (BTreeSet<String>, u64) {
let config = crate::lab::config::LabConfig::new(seed).trace_capacity(4096);
let mut runtime = LabRuntime::new(config);
let labels = execute_plan_in_lab_core(&mut runtime, seed, dag);
let events = runtime.trace().snapshot();
let fp = crate::trace::trace_fingerprint(&events);
(labels, fp)
}
#[allow(clippy::too_many_lines)]
fn execute_plan_in_lab_core(
runtime: &mut LabRuntime,
seed: u64,
dag: &PlanDag,
) -> BTreeSet<String> {
use super::PlanNode;
let root = dag.root().expect("dag has root");
let region = runtime.state.create_root_region(Budget::INFINITE);
let mut handles: Vec<Option<SharedLabHandle>> = Vec::new();
let mut task_ids: Vec<TaskId> = Vec::new();
for (idx, node) in dag.nodes.iter().enumerate() {
let (tid, raw_handle) = match node.clone() {
PlanNode::Leaf { label } => {
let yield_count = (seed as usize).wrapping_add(idx) % 4 + 1;
spawn_lab_leaf(runtime, region, label, yield_count)
}
PlanNode::Join { children } => {
let child_handles: Vec<_> = children
.iter()
.map(|c| {
handles[c.index()]
.as_ref()
.expect("child handle available")
.clone()
})
.collect();
spawn_lab_join(runtime, region, child_handles)
}
PlanNode::Race { children } => {
let child_handles: Vec<_> = children
.iter()
.map(|c| {
handles[c.index()]
.as_ref()
.expect("child handle available")
.clone()
})
.collect();
spawn_lab_race(runtime, region, child_handles)
}
PlanNode::Timeout { child, duration } => {
let child_handle = handles[child.index()]
.as_ref()
.expect("child handle available")
.clone();
spawn_lab_timeout(runtime, region, child_handle, duration)
}
};
task_ids.push(tid);
handles.push(Some(SharedLabHandle::new(raw_handle)));
}
{
let is_empty = {
let mut sched = runtime.scheduler.lock();
for tid in &task_ids {
sched.schedule(*tid, crate::types::Budget::INFINITE.priority);
}
sched.is_empty()
};
crate::tracing_compat::trace!(
"plan fixtures scheduled {} tasks (scheduler_empty={})",
task_ids.len(),
is_empty
);
#[cfg(not(feature = "tracing-integration"))]
let _ = is_empty;
}
let steps = runtime.run_with_auto_advance().steps;
crate::tracing_compat::trace!(
"plan fixtures first run finished in {} steps (quiescent={})",
steps,
runtime.is_quiescent()
);
#[cfg(not(feature = "tracing-integration"))]
let _ = steps;
let mut attempts = 0;
while !runtime.is_quiescent() && attempts < 3 {
{
let mut sched = runtime.scheduler.lock();
for (_, record) in runtime.state.tasks_iter() {
if record.is_runnable() {
sched.schedule(record.id, record.sched_priority);
}
}
}
runtime.run_with_auto_advance();
attempts += 1;
}
if !runtime.is_quiescent() {
let runnable: Vec<_> = runtime
.state
.tasks_iter()
.filter(|(_, r)| r.is_runnable())
.map(|(_, r)| format!("{:?}={:?}", r.id, r.state))
.collect();
let total = runtime.state.tasks_iter().count();
panic!(
"runtime must be quiescent after {} steps (attempts={}): runnable=[{}], total_tasks={}",
runtime.steps(),
attempts,
runnable.join(", "),
total,
);
}
handles[root.index()]
.as_ref()
.expect("root handle")
.try_join_probe()
.expect("root should be ready")
}
fn spawn_lab_leaf(
runtime: &mut LabRuntime,
region: crate::types::RegionId,
label: String,
yield_count: usize,
) -> (TaskId, TaskHandle<BTreeSet<String>>) {
let future = async move {
lab_yield_once().await;
crate::time::sleep(
crate::types::Time::ZERO,
Duration::from_secs(yield_count as u64),
)
.await;
let mut set = BTreeSet::new();
set.insert(label);
set
};
runtime
.state
.create_task(region, Budget::INFINITE, future)
.expect("spawn leaf")
}
fn spawn_lab_join(
runtime: &mut LabRuntime,
region: crate::types::RegionId,
child_handles: Vec<SharedLabHandle>,
) -> (TaskId, TaskHandle<BTreeSet<String>>) {
let future = async move {
let mut all_labels = BTreeSet::new();
for handle in &child_handles {
all_labels.extend(handle.join().await);
}
all_labels
};
runtime
.state
.create_task(region, Budget::INFINITE, future)
.expect("spawn join driver")
}
fn spawn_lab_race(
runtime: &mut LabRuntime,
region: crate::types::RegionId,
child_handles: Vec<SharedLabHandle>,
) -> (TaskId, TaskHandle<BTreeSet<String>>) {
let future = async move {
if child_handles.is_empty() {
return BTreeSet::new();
}
if child_handles.len() == 1 {
return child_handles[0].join().await;
}
let cx = crate::cx::Cx::current().expect("cx set");
let (winner_idx, winner_result) = loop {
if let Some((i, result)) = child_handles
.iter()
.enumerate()
.find_map(|(i, handle)| handle.try_join_probe().map(|result| (i, result)))
{
break (i, result);
}
if cx.checkpoint().is_err() {
return BTreeSet::new();
}
let now = cx
.timer_driver()
.map_or_else(crate::time::wall_now, |d| d.now());
crate::time::sleep(now, std::time::Duration::from_millis(10)).await;
};
for (j, handle) in child_handles.iter().enumerate() {
if j != winner_idx {
handle.abort_with_reason(CancelReason::race_loser());
}
}
*child_handles[winner_idx].inner.state.lock() = LabJoinState::Ready(winner_result.clone());
for (j, handle) in child_handles.iter().enumerate() {
if j != winner_idx {
let _ = handle.join().await;
}
}
winner_result
};
runtime
.state
.create_task(region, Budget::INFINITE, future)
.expect("spawn race driver")
}
fn spawn_lab_timeout(
runtime: &mut LabRuntime,
region: crate::types::RegionId,
child_handle: SharedLabHandle,
duration: Duration,
) -> (TaskId, TaskHandle<BTreeSet<String>>) {
let future = async move {
if let Ok(result) =
crate::time::timeout(crate::types::Time::ZERO, duration, child_handle.join()).await
{
result
} else {
child_handle.abort_with_reason(CancelReason::timeout());
let _ = child_handle.join().await;
BTreeSet::new()
}
};
runtime
.state
.create_task(region, Budget::INFINITE, future)
.expect("spawn timeout driver")
}
#[derive(Debug, Clone)]
#[allow(clippy::struct_excessive_bools)]
pub struct LabDynamicEquivalenceReport {
pub fixture_name: &'static str,
pub certificate: RewriteCertificate,
pub certificate_verified: bool,
pub steps_verified: bool,
pub static_outcomes_equivalent: bool,
pub seeds: Vec<u64>,
pub per_seed_results: Vec<(BTreeSet<String>, BTreeSet<String>, bool)>,
pub dynamic_outcomes_equivalent: bool,
pub original_outcome_universe: BTreeSet<Vec<String>>,
pub optimized_outcome_universe: BTreeSet<Vec<String>>,
pub universes_match: bool,
}
impl LabDynamicEquivalenceReport {
#[must_use]
pub fn all_ok(&self) -> bool {
self.certificate_verified
&& self.steps_verified
&& self.static_outcomes_equivalent
&& self.dynamic_outcomes_equivalent
&& self.universes_match
}
#[must_use]
pub fn failure_summary(&self) -> Option<String> {
if self.all_ok() {
return None;
}
let mut out = format!("Fixture: {}\n", self.fixture_name);
if !self.certificate_verified {
out.push_str(" CERTIFICATE HASH MISMATCH\n");
}
if !self.steps_verified {
out.push_str(" STEP VERIFICATION FAILED\n");
}
if !self.static_outcomes_equivalent {
out.push_str(" STATIC OUTCOME MISMATCH\n");
}
if !self.dynamic_outcomes_equivalent {
out.push_str(" DYNAMIC OUTCOME MISMATCH (per-seed):\n");
for (i, (orig, opt, ok)) in self.per_seed_results.iter().enumerate() {
if !ok {
let _ = writeln!(
&mut out,
" seed {}: original={:?} optimized={:?}",
self.seeds[i], orig, opt
);
}
}
}
if !self.universes_match {
let _ = write!(
&mut out,
" UNIVERSE MISMATCH:\n original: {:?}\n optimized: {:?}\n",
self.original_outcome_universe, self.optimized_outcome_universe
);
}
Some(out)
}
}
#[must_use]
pub fn run_lab_dynamic_equivalence(
fixture: PlanFixture,
policy: RewritePolicy,
rules: &[RewriteRule],
seeds: &[u64],
) -> LabDynamicEquivalenceReport {
let original_dag = fixture.dag.clone();
let original_static = original_dag
.root()
.map(|r| outcome_sets(&original_dag, r))
.unwrap_or_default();
let mut optimized_dag = fixture.dag;
let (_, certificate) = optimized_dag.apply_rewrites_certified(policy, rules);
let optimized_static = optimized_dag
.root()
.map(|r| outcome_sets(&optimized_dag, r))
.unwrap_or_default();
let static_outcomes_equivalent = original_static == optimized_static;
let certificate_verified = verify(&certificate, &optimized_dag).is_ok();
let steps_verified = verify_steps(&certificate, &optimized_dag).is_ok();
let mut per_seed_results = Vec::with_capacity(seeds.len());
let mut original_universe = BTreeSet::new();
let mut optimized_universe = BTreeSet::new();
let mut all_dynamic_ok = true;
for &seed in seeds {
let orig_labels = execute_plan_in_lab(seed, &original_dag);
let opt_labels = execute_plan_in_lab(seed, &optimized_dag);
let ok = orig_labels == opt_labels;
if !ok {
all_dynamic_ok = false;
}
original_universe.insert(orig_labels.iter().cloned().collect::<Vec<_>>());
optimized_universe.insert(opt_labels.iter().cloned().collect::<Vec<_>>());
per_seed_results.push((orig_labels, opt_labels, ok));
}
let universes_match = original_universe == optimized_universe;
LabDynamicEquivalenceReport {
fixture_name: fixture.name,
certificate,
certificate_verified,
steps_verified,
static_outcomes_equivalent,
seeds: seeds.to_vec(),
per_seed_results,
dynamic_outcomes_equivalent: all_dynamic_ok,
original_outcome_universe: original_universe,
optimized_outcome_universe: optimized_universe,
universes_match,
}
}
fn compute_dag_cost(dag: &PlanDag) -> PlanCost {
use super::PlanNode;
fn recurse(dag: &PlanDag, id: PlanId, memo: &mut HashMap<PlanId, PlanCost>) -> PlanCost {
if let Some(&c) = memo.get(&id) {
return c;
}
let node = dag.node(id).expect("valid PlanId");
let cost = match node.clone() {
PlanNode::Leaf { .. } => PlanCost {
allocations: 1,
cancel_checkpoints: 0,
obligation_pressure: 0,
critical_path: 1,
},
PlanNode::Join { children } => {
let child_costs: Vec<_> = children.iter().map(|c| recurse(dag, *c, memo)).collect();
let allocs: u64 = child_costs.iter().map(|c| c.allocations).sum::<u64>() + 1;
let cp = child_costs
.iter()
.map(|c| c.critical_path)
.max()
.unwrap_or(0)
+ 1;
let obl: u64 = child_costs.iter().map(|c| c.obligation_pressure).sum();
PlanCost {
allocations: allocs,
cancel_checkpoints: 0,
obligation_pressure: obl,
critical_path: cp,
}
}
PlanNode::Race { children } => {
let child_costs: Vec<_> = children.iter().map(|c| recurse(dag, *c, memo)).collect();
let allocs: u64 = child_costs.iter().map(|c| c.allocations).sum::<u64>() + 1;
let cp = child_costs
.iter()
.map(|c| c.critical_path)
.min()
.unwrap_or(0)
+ 1;
let cancel_cps: u64 = child_costs
.iter()
.map(|c| c.cancel_checkpoints)
.sum::<u64>()
+ children.len() as u64;
PlanCost {
allocations: allocs,
cancel_checkpoints: cancel_cps,
obligation_pressure: 0,
critical_path: cp,
}
}
PlanNode::Timeout { child, .. } => {
let child_cost = recurse(dag, child, memo);
PlanCost {
allocations: child_cost.allocations + 1,
cancel_checkpoints: child_cost.cancel_checkpoints + 1,
obligation_pressure: child_cost.obligation_pressure + 1,
critical_path: child_cost.critical_path + 1,
}
}
};
memo.insert(id, cost);
cost
}
let Some(root) = dag.root() else {
return PlanCost::default();
};
let mut memo = HashMap::new();
recurse(dag, root, &mut memo)
}
#[derive(Debug, Clone)]
#[allow(clippy::struct_excessive_bools)]
pub struct E2ePipelineReport {
pub fixture_name: &'static str,
pub certificate_verified: bool,
pub steps_verified: bool,
pub outcomes_equivalent: bool,
pub extraction_equivalent: bool,
pub rewrite_extraction_equivalent: bool,
pub dynamic_outcomes_equivalent: bool,
pub certificate_fingerprint: u64,
pub original_cost: PlanCost,
pub optimized_cost: PlanCost,
pub original_trace_fingerprint: u64,
pub optimized_trace_fingerprint: u64,
pub dynamic_original_labels: BTreeSet<String>,
pub dynamic_optimized_labels: BTreeSet<String>,
pub rewrite_count: usize,
}
impl E2ePipelineReport {
#[must_use]
pub fn all_ok(&self) -> bool {
self.certificate_verified
&& self.steps_verified
&& self.outcomes_equivalent
&& self.extraction_equivalent
&& self.rewrite_extraction_equivalent
&& self.dynamic_outcomes_equivalent
}
#[must_use]
pub fn golden_fingerprint(&self) -> u64 {
const FNV_OFFSET: u64 = 0xcbf2_9ce4_8422_2325;
const FNV_PRIME: u64 = 0x0100_0000_01b3;
let mut h = FNV_OFFSET;
let mix = |h: &mut u64, v: u64| {
*h ^= v;
*h = h.wrapping_mul(FNV_PRIME);
};
mix(&mut h, self.certificate_fingerprint);
mix(&mut h, self.original_cost.total());
mix(&mut h, self.optimized_cost.total());
mix(&mut h, self.original_trace_fingerprint);
mix(&mut h, self.optimized_trace_fingerprint);
mix(&mut h, self.rewrite_count as u64);
for label in &self.dynamic_original_labels {
for byte in label.bytes() {
h ^= u64::from(byte);
h = h.wrapping_mul(FNV_PRIME);
}
}
h
}
#[must_use]
pub fn cost_delta(&self) -> i128 {
i128::from(self.original_cost.total()) - i128::from(self.optimized_cost.total())
}
}
#[must_use]
pub fn run_e2e_pipeline(
fixture: PlanFixture,
policy: RewritePolicy,
rules: &[RewriteRule],
) -> E2ePipelineReport {
use super::certificate::{verify, verify_steps};
let original_dag = fixture.dag.clone();
let original_cost = compute_dag_cost(&original_dag);
let original_static = original_dag
.root()
.map(|r| outcome_sets(&original_dag, r))
.unwrap_or_default();
let mut optimized_dag = fixture.dag;
let (report, certificate) = optimized_dag.apply_rewrites_certified(policy, rules);
let optimized_cost = compute_dag_cost(&optimized_dag);
let rewrite_count = report.steps().len();
let optimized_static = optimized_dag
.root()
.map(|r| outcome_sets(&optimized_dag, r))
.unwrap_or_default();
let outcomes_equivalent = original_static == optimized_static;
let certificate_verified = verify(&certificate, &optimized_dag).is_ok();
let steps_verified = verify_steps(&certificate, &optimized_dag).is_ok();
let certificate_fingerprint = certificate.fingerprint();
let extraction_equivalent = {
use super::extractor::Extractor;
let mut eg = crate::plan::EGraph::new();
let mut cache = HashMap::new();
original_dag.root().is_none_or(|root| {
let root_ec = dag_to_egraph_rec(&original_dag, root, &mut eg, &mut cache);
let (extracted, _) = Extractor::new(&mut eg)
.extract(root_ec)
.expect("original DAG should round-trip through extraction");
let extracted_outcomes = extracted
.root()
.map(|r| outcome_sets(&extracted, r))
.unwrap_or_default();
original_static == extracted_outcomes
})
};
let rewrite_extraction_equivalent = {
use super::extractor::Extractor;
let mut eg = crate::plan::EGraph::new();
let mut cache = HashMap::new();
optimized_dag.root().is_none_or(|root| {
let root_ec = dag_to_egraph_rec(&optimized_dag, root, &mut eg, &mut cache);
let (extracted, _) = Extractor::new(&mut eg)
.extract(root_ec)
.expect("optimized DAG should round-trip through extraction");
let extracted_outcomes = extracted
.root()
.map(|r| outcome_sets(&extracted, r))
.unwrap_or_default();
original_static == extracted_outcomes
})
};
let (dynamic_original_labels, original_trace_fingerprint) =
execute_plan_in_lab_traced(42, &original_dag);
let (dynamic_optimized_labels, optimized_trace_fingerprint) =
execute_plan_in_lab_traced(42, &optimized_dag);
let dynamic_outcomes_equivalent = dynamic_original_labels == dynamic_optimized_labels;
E2ePipelineReport {
fixture_name: fixture.name,
certificate_verified,
steps_verified,
outcomes_equivalent,
extraction_equivalent,
rewrite_extraction_equivalent,
dynamic_outcomes_equivalent,
certificate_fingerprint,
original_cost,
optimized_cost,
original_trace_fingerprint,
optimized_trace_fingerprint,
dynamic_original_labels,
dynamic_optimized_labels,
rewrite_count,
}
}
#[must_use]
pub fn run_e2e_pipeline_all(
policy: RewritePolicy,
rules: &[RewriteRule],
) -> Vec<E2ePipelineReport> {
all_fixtures()
.into_iter()
.map(|f| run_e2e_pipeline(f, policy, rules))
.collect()
}
fn dag_to_egraph_rec(
dag: &PlanDag,
id: PlanId,
eg: &mut crate::plan::EGraph,
cache: &mut HashMap<PlanId, crate::plan::EClassId>,
) -> crate::plan::EClassId {
if let Some(&ec) = cache.get(&id) {
return ec;
}
let node = dag.node(id).expect("valid PlanId");
let eclass = match node.clone() {
PlanNode::Leaf { label } => eg.add_leaf(label),
PlanNode::Join { children } => {
let ec: Vec<_> = children
.iter()
.map(|c| dag_to_egraph_rec(dag, *c, eg, cache))
.collect();
eg.add_join(ec)
}
PlanNode::Race { children } => {
let ec: Vec<_> = children
.iter()
.map(|c| dag_to_egraph_rec(dag, *c, eg, cache))
.collect();
eg.add_race(ec)
}
PlanNode::Timeout { child, duration } => {
let child_ec = dag_to_egraph_rec(dag, child, eg, cache);
eg.add_timeout(child_ec, duration)
}
};
cache.insert(id, eclass);
eclass
}
#[cfg(test)]
mod tests {
use super::*;
use crate::plan::certificate::{verify, verify_steps};
use crate::plan::rewrite::{RewritePolicy, RewriteRule};
use crate::test_utils::init_test_logging;
fn init_test() {
init_test_logging();
}
#[test]
fn all_fixtures_validate() {
init_test();
for fixture in all_fixtures() {
assert!(
fixture.dag.validate().is_ok(),
"fixture {} failed validation",
fixture.name
);
}
}
#[test]
fn fixture_count_at_least_10() {
init_test();
let fixtures = all_fixtures();
assert!(
fixtures.len() >= 10,
"need >= 10 fixtures, got {}",
fixtures.len()
);
}
#[test]
fn non_test_plan_fixture_paths_do_not_print_directly() {
init_test();
let source = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/plan/fixtures.rs"));
let non_test = source.split("#[cfg(test)]").next().unwrap_or(source);
assert!(
!non_test.contains("println!(") && !non_test.contains("eprintln!("),
"non-test plan fixture paths should use tracing instead of stdout/stderr"
);
}
#[test]
fn conservative_rewrites_match_expected_counts() {
init_test();
let rules = [RewriteRule::DedupRaceJoin];
for mut fixture in all_fixtures() {
if fixture.name == "shared_non_leaf_associative" {
continue;
}
let (report, cert) = fixture
.dag
.apply_rewrites_certified(RewritePolicy::conservative(), &rules);
assert_eq!(
report.steps().len(),
fixture.expected_step_count,
"fixture {}: expected {} steps, got {}",
fixture.name,
fixture.expected_step_count,
report.steps().len()
);
assert!(
verify(&cert, &fixture.dag).is_ok(),
"fixture {}: certificate verification failed",
fixture.name
);
assert!(
verify_steps(&cert, &fixture.dag).is_ok(),
"fixture {}: step verification failed",
fixture.name
);
}
}
#[test]
fn associative_policy_fires_on_non_leaf_shared() {
init_test();
let rules = [RewriteRule::DedupRaceJoin];
let fixtures = all_fixtures();
let fixture = fixtures
.iter()
.find(|f| f.name == "shared_non_leaf_associative")
.expect("fixture exists");
let mut dag = PlanDag::new();
let x = dag.leaf("x");
let y = dag.leaf("y");
let shared_join = dag.join(vec![x, y]);
let a = dag.leaf("a");
let b = dag.leaf("b");
let j1 = dag.join(vec![shared_join, a]);
let j2 = dag.join(vec![shared_join, b]);
let race = dag.race(vec![j1, j2]);
dag.set_root(race);
let (report, cert) = dag.apply_rewrites_certified(RewritePolicy::assume_all(), &rules);
assert_eq!(report.steps().len(), fixture.expected_step_count);
assert!(verify(&cert, &dag).is_ok());
}
#[test]
fn lab_equivalence_all_fixtures_conservative() {
init_test();
let rules = [RewriteRule::DedupRaceJoin];
for fixture in all_fixtures() {
if fixture.name == "shared_non_leaf_associative" {
continue;
}
let report = run_equivalence_harness(fixture, RewritePolicy::conservative(), &rules);
if let Some(diff) = report.diff_summary() {
panic!("equivalence failure:\n{diff}");
}
assert!(report.all_ok());
}
}
#[test]
#[allow(clippy::similar_names)]
fn rule_witness_golden_fingerprints() {
init_test();
let rules = [RewriteRule::DedupRaceJoin];
let mut f1 = simple_join_race_dedup();
let (_, cert1) = f1
.dag
.apply_rewrites_certified(RewritePolicy::conservative(), &rules);
assert_eq!(cert1.steps.len(), 1, "F1 must have exactly 1 rewrite step");
assert!(verify(&cert1, &f1.dag).is_ok(), "F1 cert must verify");
assert!(
verify_steps(&cert1, &f1.dag).is_ok(),
"F1 steps must verify"
);
let fp1 = cert1.fingerprint();
assert_ne!(fp1, 0, "fingerprint must be nonzero");
assert_ne!(cert1.before_hash, cert1.after_hash);
let mut f3 = nested_timeout_join_race();
let (_, cert3) = f3
.dag
.apply_rewrites_certified(RewritePolicy::conservative(), &rules);
assert_eq!(cert3.steps.len(), 1);
assert!(verify(&cert3, &f3.dag).is_ok());
assert!(verify_steps(&cert3, &f3.dag).is_ok());
let fp3 = cert3.fingerprint();
assert_ne!(
fp3, fp1,
"different fixtures must produce different fingerprints"
);
let mut fixture_timeout = timeout_wrapping_dedup();
let (_, cert_timeout) = fixture_timeout
.dag
.apply_rewrites_certified(RewritePolicy::conservative(), &rules);
assert_eq!(cert_timeout.steps.len(), 1);
assert!(verify(&cert_timeout, &fixture_timeout.dag).is_ok());
assert!(verify_steps(&cert_timeout, &fixture_timeout.dag).is_ok());
let fp_timeout = cert_timeout.fingerprint();
assert_ne!(fp_timeout, fp1);
assert_ne!(fp_timeout, fp3);
let mut fixture_no_shared_a = no_shared_child();
let (_, cert_no_shared_a) = fixture_no_shared_a
.dag
.apply_rewrites_certified(RewritePolicy::conservative(), &rules);
assert!(cert_no_shared_a.is_identity());
let mut fixture_no_shared_b = no_shared_child();
let (_, cert_no_shared_b) = fixture_no_shared_b
.dag
.apply_rewrites_certified(RewritePolicy::conservative(), &rules);
assert_eq!(
cert_no_shared_a.fingerprint(),
cert_no_shared_b.fingerprint()
);
}
#[test]
fn egraph_determinism_golden_hashes() {
use crate::plan::EGraph;
init_test();
let build = || {
let mut eg = EGraph::new();
let a = eg.add_leaf("a");
let b = eg.add_leaf("b");
let c = eg.add_leaf("c");
let join_ab = eg.add_join(vec![a, b]);
let join_ab2 = eg.add_join(vec![a, b]); let race_bc = eg.add_race(vec![b, c]);
let top = eg.add_join(vec![join_ab, race_bc]);
(eg, a, b, c, join_ab, join_ab2, race_bc, top)
};
let (mut eg1, a1, b1, _, j1, j1_dup, r1, t1) = build();
let (mut eg2, a2, b2, _, j2, j2_dup, r2, t2) = build();
assert_eq!(eg1.canonical_id(j1), eg1.canonical_id(j1_dup));
assert_eq!(eg2.canonical_id(j2), eg2.canonical_id(j2_dup));
assert_eq!(eg1.canonical_id(a1).index(), eg2.canonical_id(a2).index());
assert_eq!(eg1.canonical_id(b1).index(), eg2.canonical_id(b2).index());
assert_eq!(eg1.canonical_id(j1).index(), eg2.canonical_id(j2).index());
assert_eq!(eg1.canonical_id(r1).index(), eg2.canonical_id(r2).index());
assert_eq!(eg1.canonical_id(t1).index(), eg2.canonical_id(t2).index());
let merged1 = eg1.merge(b1, a1);
let merged2 = eg2.merge(b2, a2);
assert_eq!(merged1.index(), merged2.index());
assert_eq!(merged1, a1); }
#[test]
fn cancel_fixtures_present_and_valid() {
init_test();
let fixtures = all_fixtures();
let cancel_names: Vec<&str> = fixtures
.iter()
.filter(|f| {
f.name.contains("cancel")
|| f.intent.contains("cancel")
|| f.intent.contains("Cancel")
})
.map(|f| f.name)
.collect();
assert!(
cancel_names.len() >= 4,
"need >= 4 cancel-aware fixtures, got {}: {:?}",
cancel_names.len(),
cancel_names
);
for fixture in &fixtures {
assert!(
fixture.dag.validate().is_ok(),
"cancel fixture {} failed validation",
fixture.name
);
}
}
#[test]
fn lab_equivalence_all_fixtures_all_rules() {
init_test();
let all_rules = [
RewriteRule::DedupRaceJoin,
RewriteRule::JoinAssoc,
RewriteRule::RaceAssoc,
RewriteRule::JoinCommute,
RewriteRule::RaceCommute,
RewriteRule::TimeoutMin,
];
for fixture in all_fixtures() {
if fixture.name == "shared_non_leaf_associative" {
continue;
}
let report =
run_equivalence_harness(fixture, RewritePolicy::conservative(), &all_rules);
assert!(
report.outcomes_equivalent,
"outcomes not equivalent for fixture {}",
report.fixture_name
);
}
}
#[test]
fn extraction_pipeline_equivalence() {
use crate::plan::PlanId;
use crate::plan::extractor::Extractor;
use std::collections::HashMap;
init_test();
for fixture in all_fixtures() {
let original_outcomes = fixture
.dag
.root()
.map(|root| outcome_sets(&fixture.dag, root))
.unwrap_or_default();
let mut eg = crate::plan::EGraph::new();
let mut cache: HashMap<PlanId, crate::plan::EClassId> = HashMap::new();
if let Some(root) = fixture.dag.root() {
let root_eclass = dag_to_egraph_rec(&fixture.dag, root, &mut eg, &mut cache);
let (extracted_dag, _cert) = Extractor::new(&mut eg)
.extract(root_eclass)
.expect("fixture DAG should extract successfully");
let extracted_outcomes = extracted_dag
.root()
.map(|r| outcome_sets(&extracted_dag, r))
.unwrap_or_default();
assert_eq!(
original_outcomes, extracted_outcomes,
"extraction changed outcomes for fixture {}",
fixture.name
);
}
}
}
fn dag_to_egraph_rec(
dag: &PlanDag,
id: PlanId,
eg: &mut crate::plan::EGraph,
cache: &mut HashMap<PlanId, crate::plan::EClassId>,
) -> crate::plan::EClassId {
if let Some(&ec) = cache.get(&id) {
return ec;
}
let node = dag.node(id).expect("valid PlanId");
let eclass = match node.clone() {
PlanNode::Leaf { label } => eg.add_leaf(label),
PlanNode::Join { children } => {
let ec: Vec<_> = children
.iter()
.map(|c| dag_to_egraph_rec(dag, *c, eg, cache))
.collect();
eg.add_join(ec)
}
PlanNode::Race { children } => {
let ec: Vec<_> = children
.iter()
.map(|c| dag_to_egraph_rec(dag, *c, eg, cache))
.collect();
eg.add_race(ec)
}
PlanNode::Timeout { child, duration } => {
let child_ec = dag_to_egraph_rec(dag, child, eg, cache);
eg.add_timeout(child_ec, duration)
}
};
cache.insert(id, eclass);
eclass
}
#[test]
fn extraction_after_rewrite_equivalence() {
use crate::plan::PlanId;
use crate::plan::extractor::Extractor;
use std::collections::HashMap;
init_test();
let rules = [RewriteRule::DedupRaceJoin];
for mut fixture in all_fixtures() {
if fixture.name == "shared_non_leaf_associative" {
continue;
}
let original_outcomes = fixture
.dag
.root()
.map(|root| outcome_sets(&fixture.dag, root))
.unwrap_or_default();
let (_report, _cert) = fixture
.dag
.apply_rewrites_certified(RewritePolicy::conservative(), &rules);
let mut eg = crate::plan::EGraph::new();
let mut cache: HashMap<PlanId, crate::plan::EClassId> = HashMap::new();
if let Some(root) = fixture.dag.root() {
let root_eclass = dag_to_egraph_rec(&fixture.dag, root, &mut eg, &mut cache);
let (extracted_dag, _cert) = Extractor::new(&mut eg)
.extract(root_eclass)
.expect("rewritten fixture DAG should extract successfully");
let extracted_outcomes = extracted_dag
.root()
.map(|r| outcome_sets(&extracted_dag, r))
.unwrap_or_default();
assert_eq!(
original_outcomes, extracted_outcomes,
"rewrite+extraction changed outcomes for fixture {}",
fixture.name
);
}
}
}
#[test]
fn lab_equivalence_deterministic_across_runs() {
init_test();
let rules = [RewriteRule::DedupRaceJoin];
let reports1: Vec<_> = all_fixtures()
.into_iter()
.filter(|f| f.name != "shared_non_leaf_associative")
.map(|f| run_equivalence_harness(f, RewritePolicy::conservative(), &rules))
.collect();
let reports2: Vec<_> = all_fixtures()
.into_iter()
.filter(|f| f.name != "shared_non_leaf_associative")
.map(|f| run_equivalence_harness(f, RewritePolicy::conservative(), &rules))
.collect();
assert_eq!(reports1.len(), reports2.len());
for (r1, r2) in reports1.iter().zip(reports2.iter()) {
assert_eq!(
r1.original_hash, r2.original_hash,
"{}: original hash mismatch across runs",
r1.fixture_name
);
assert_eq!(
r1.optimized_hash, r2.optimized_hash,
"{}: optimized hash mismatch across runs",
r1.fixture_name
);
assert_eq!(
r1.original_outcomes, r2.original_outcomes,
"{}: outcomes differ across runs",
r1.fixture_name
);
}
}
#[test]
fn e2e_pipeline_all_fixtures_pass() {
init_test();
let rules = [RewriteRule::DedupRaceJoin];
let reports = run_e2e_pipeline_all(RewritePolicy::conservative(), &rules);
assert!(
reports.len() >= 16,
"expected >= 16 E2E reports, got {}",
reports.len()
);
for report in &reports {
assert!(
report.all_ok(),
"E2E pipeline failed for fixture {}: cert_ok={}, steps_ok={}, outcomes_eq={}, \
extract_eq={}, rewrite_extract_eq={}, dynamic_eq={}",
report.fixture_name,
report.certificate_verified,
report.steps_verified,
report.outcomes_equivalent,
report.extraction_equivalent,
report.rewrite_extraction_equivalent,
report.dynamic_outcomes_equivalent,
);
}
}
#[test]
fn e2e_pipeline_deterministic_across_runs() {
init_test();
let rules = [RewriteRule::DedupRaceJoin];
let reports1 = run_e2e_pipeline_all(RewritePolicy::conservative(), &rules);
let reports2 = run_e2e_pipeline_all(RewritePolicy::conservative(), &rules);
assert_eq!(reports1.len(), reports2.len());
for (r1, r2) in reports1.iter().zip(reports2.iter()) {
assert_eq!(
r1.golden_fingerprint(),
r2.golden_fingerprint(),
"fixture {}: E2E golden fingerprint differs across runs",
r1.fixture_name
);
assert_eq!(
r1.certificate_fingerprint, r2.certificate_fingerprint,
"fixture {}: certificate fingerprint differs across runs",
r1.fixture_name
);
assert_eq!(
r1.original_cost.total(),
r2.original_cost.total(),
"fixture {}: original cost differs across runs",
r1.fixture_name
);
assert_eq!(
r1.optimized_cost.total(),
r2.optimized_cost.total(),
"fixture {}: optimized cost differs across runs",
r1.fixture_name
);
assert_eq!(
r1.original_trace_fingerprint, r2.original_trace_fingerprint,
"fixture {}: original trace fingerprint differs across runs",
r1.fixture_name
);
assert_eq!(
r1.optimized_trace_fingerprint, r2.optimized_trace_fingerprint,
"fixture {}: optimized trace fingerprint differs across runs",
r1.fixture_name
);
}
}
#[test]
fn e2e_pipeline_cost_never_increases() {
init_test();
let rules = [RewriteRule::DedupRaceJoin];
let reports = run_e2e_pipeline_all(RewritePolicy::conservative(), &rules);
for report in &reports {
assert!(
report.optimized_cost <= report.original_cost,
"fixture {}: cost increased from {} to {} after rewrite",
report.fixture_name,
report.original_cost.total(),
report.optimized_cost.total(),
);
}
}
#[test]
fn e2e_pipeline_dynamic_labels_populated() {
init_test();
let rules = [RewriteRule::DedupRaceJoin];
let reports = run_e2e_pipeline_all(RewritePolicy::conservative(), &rules);
let mut have_dynamic = 0;
for report in &reports {
if report.dynamic_original_labels.is_empty()
|| report.dynamic_optimized_labels.is_empty()
{
continue;
}
have_dynamic += 1;
}
assert!(have_dynamic > 0, "no fixtures produced dynamic labels");
}
#[test]
fn e2e_pipeline_trace_fingerprints_nonzero() {
init_test();
let rules = [RewriteRule::DedupRaceJoin];
let reports = run_e2e_pipeline_all(RewritePolicy::conservative(), &rules);
let mut have_traces = 0;
for report in &reports {
if report.original_trace_fingerprint == 0 || report.optimized_trace_fingerprint == 0 {
continue;
}
have_traces += 1;
}
assert!(have_traces > 0, "no fixtures produced trace fingerprints");
}
#[test]
fn e2e_pipeline_cost_delta_sane() {
init_test();
let rules = [RewriteRule::DedupRaceJoin];
let reports = run_e2e_pipeline_all(RewritePolicy::conservative(), &rules);
for report in &reports {
let delta = report.cost_delta();
if report.rewrite_count > 0 {
assert!(
delta >= 0,
"fixture {}: rewrite increased cost (delta={delta}, original={}, optimized={})",
report.fixture_name,
report.original_cost.total(),
report.optimized_cost.total(),
);
}
}
}
const ORACLE_SEEDS: [u64; 8] = [0, 1, 2, 3, 42, 99, 1000, u64::MAX];
#[test]
fn dynamic_lab_equivalence_all_fixtures_conservative() {
init_test();
let rules = [RewriteRule::DedupRaceJoin];
for fixture in all_fixtures() {
if fixture.name == "shared_non_leaf_associative" {
continue;
}
let report = run_lab_dynamic_equivalence(
fixture,
RewritePolicy::conservative(),
&rules,
&ORACLE_SEEDS,
);
assert!(
report.all_ok(),
"{}",
report
.failure_summary()
.unwrap_or_else(|| "unknown failure".into())
);
}
}
#[test]
fn dynamic_lab_equivalence_associative_policy() {
init_test();
let rules = [RewriteRule::DedupRaceJoin];
let fixtures = all_fixtures();
let fixture = fixtures
.into_iter()
.find(|f| f.name == "shared_non_leaf_associative")
.expect("fixture exists");
let report = run_lab_dynamic_equivalence(
fixture,
RewritePolicy::assume_all(),
&rules,
&ORACLE_SEEDS,
);
assert!(
report.all_ok(),
"{}",
report
.failure_summary()
.unwrap_or_else(|| "unknown failure".into())
);
}
#[test]
fn dynamic_lab_single_leaf_execution() {
init_test();
let mut dag = PlanDag::new();
let a = dag.leaf("alpha");
dag.set_root(a);
let result = execute_plan_in_lab(42, &dag);
assert_eq!(result.len(), 1);
assert!(result.contains("alpha"));
}
#[test]
fn dynamic_lab_join_collects_all_leaves() {
init_test();
let mut dag = PlanDag::new();
let a = dag.leaf("a");
let b = dag.leaf("b");
let c = dag.leaf("c");
let j = dag.join(vec![a, b, c]);
dag.set_root(j);
let result = execute_plan_in_lab(0, &dag);
assert_eq!(result.len(), 3);
assert!(result.contains("a"));
assert!(result.contains("b"));
assert!(result.contains("c"));
}
#[test]
fn dynamic_lab_race_returns_subset() {
init_test();
let mut dag = PlanDag::new();
let a = dag.leaf("a");
let b = dag.leaf("b");
let r = dag.race(vec![a, b]);
dag.set_root(r);
for seed in &ORACLE_SEEDS {
let result = execute_plan_in_lab(*seed, &dag);
assert_eq!(
result.len(),
1,
"seed {seed}: race should yield exactly one winner"
);
assert!(
result.contains("a") || result.contains("b"),
"seed {seed}: winner must be a or b"
);
}
}
#[test]
fn dag_has_fan_in_ignores_unreachable_rewrite_orphans() {
init_test();
let rules = [RewriteRule::DedupRaceJoin];
let mut fixture = simple_join_race_dedup();
assert!(
dag_has_fan_in(&fixture.dag),
"original fixture must have live fan-in"
);
let (_report, _cert) = fixture
.dag
.apply_rewrites_certified(RewritePolicy::conservative(), &rules);
assert!(
!dag_has_fan_in(&fixture.dag),
"rewritten reachable DAG should be a tree; orphaned pre-rewrite nodes must not count as live fan-in"
);
}
#[test]
fn dynamic_lab_executes_fan_in_fixture() {
init_test();
let fixture = simple_join_race_dedup();
assert!(
dag_has_fan_in(&fixture.dag),
"fixture should exercise fan-in"
);
for seed in &ORACLE_SEEDS {
let result = execute_plan_in_lab(*seed, &fixture.dag);
assert_eq!(
result.len(),
2,
"seed {seed}: dedup witness race should produce the shared leaf plus one branch winner"
);
assert!(
result.contains("shared"),
"seed {seed}: shared leaf must complete"
);
assert!(
result.contains("left") || result.contains("right"),
"seed {seed}: winner must include exactly one branch leaf"
);
}
}
#[test]
fn dynamic_lab_deterministic_same_seed() {
init_test();
for fixture in all_fixtures() {
let r1 = execute_plan_in_lab(42, &fixture.dag);
let r2 = execute_plan_in_lab(42, &fixture.dag);
assert_eq!(
r1, r2,
"fixture {}: same seed must produce identical results",
fixture.name
);
}
}
#[test]
fn dynamic_lab_timeout_passes_through() {
init_test();
let mut dag = PlanDag::new();
let a = dag.leaf("inner");
let t = dag.timeout(a, Duration::from_secs(10));
dag.set_root(t);
let result = execute_plan_in_lab(7, &dag);
assert_eq!(result.len(), 1);
assert!(result.contains("inner"));
}
#[test]
fn dynamic_lab_timeout_cancels_slow_child() {
init_test();
let mut dag = PlanDag::new();
let a = dag.leaf("inner");
let t = dag.timeout(a, Duration::from_secs(2));
dag.set_root(t);
let result = execute_plan_in_lab(7, &dag);
assert!(result.is_empty(), "short timeout should cancel slow child");
}
#[test]
fn dynamic_lab_nested_join_race() {
init_test();
let mut dag = PlanDag::new();
let a = dag.leaf("a");
let b = dag.leaf("b");
let c = dag.leaf("c");
let j_ab = dag.join(vec![a, b]);
let r = dag.race(vec![j_ab, c]);
dag.set_root(r);
for seed in &ORACLE_SEEDS {
let result = execute_plan_in_lab(*seed, &dag);
let is_join_winner = result.len() == 2 && result.contains("a") && result.contains("b");
let is_leaf_winner = result.len() == 1 && result.contains("c");
assert!(
is_join_winner || is_leaf_winner,
"seed {seed}: unexpected result {result:?}"
);
}
}
#[test]
fn dynamic_lab_report_fields_populated() {
init_test();
let rules = [RewriteRule::DedupRaceJoin];
let fixture = simple_join_race_dedup();
let report = run_lab_dynamic_equivalence(
fixture,
RewritePolicy::conservative(),
&rules,
&ORACLE_SEEDS,
);
assert_eq!(report.fixture_name, "simple_join_race_dedup");
assert_eq!(report.seeds.len(), ORACLE_SEEDS.len());
assert_eq!(report.per_seed_results.len(), ORACLE_SEEDS.len());
assert!(!report.original_outcome_universe.is_empty());
assert!(!report.optimized_outcome_universe.is_empty());
}
}