use super::analysis::SideConditionChecker;
use super::rewrite::{
RewritePolicy, RewriteReport, RewriteRule, RewriteStep, check_side_conditions,
};
use super::{PlanDag, PlanId, PlanNode};
const FNV_OFFSET: u64 = 0xcbf2_9ce4_8422_2325;
const FNV_PRIME: u64 = 0x0100_0000_01b3;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PlanHash(u64);
impl PlanHash {
#[must_use]
pub const fn value(self) -> u64 {
self.0
}
#[must_use]
pub fn of(dag: &PlanDag) -> Self {
let mut h = FNV_OFFSET;
h = fnv_u64(h, dag.nodes.len() as u64);
for node in &dag.nodes {
h = hash_node(h, node);
}
match dag.root {
Some(id) => {
h = fnv_u8(h, 1);
h = fnv_u64(h, id.index() as u64);
}
None => {
h = fnv_u8(h, 0);
}
}
Self(h)
}
}
fn fnv_u8(h: u64, byte: u8) -> u64 {
(h ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
}
fn fnv_u64(mut h: u64, val: u64) -> u64 {
for &byte in &val.to_le_bytes() {
h = fnv_u8(h, byte);
}
h
}
fn fnv_bytes(mut h: u64, bytes: &[u8]) -> u64 {
for &byte in bytes {
h = fnv_u8(h, byte);
}
h
}
fn hash_node(mut h: u64, node: &PlanNode) -> u64 {
match node {
PlanNode::Leaf { label } => {
h = fnv_u8(h, 0); h = fnv_u64(h, label.len() as u64);
h = fnv_bytes(h, label.as_bytes());
}
PlanNode::Join { children } => {
h = fnv_u8(h, 1);
h = fnv_u64(h, children.len() as u64);
for child in children {
h = fnv_u64(h, child.index() as u64);
}
}
PlanNode::Race { children } => {
h = fnv_u8(h, 2);
h = fnv_u64(h, children.len() as u64);
for child in children {
h = fnv_u64(h, child.index() as u64);
}
}
PlanNode::Timeout { child, duration } => {
h = fnv_u8(h, 3);
h = fnv_u64(h, child.index() as u64);
h = fnv_u64(h, u64::try_from(duration.as_nanos()).unwrap_or(u64::MAX));
}
}
h
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CertificateVersion(u32);
impl CertificateVersion {
pub const CURRENT: Self = Self(1);
#[must_use]
pub const fn number(self) -> u32 {
self.0
}
#[cfg(test)]
pub(crate) const fn from_number(n: u32) -> Self {
Self(n)
}
}
#[derive(Debug, Clone)]
pub struct CertifiedStep {
pub rule: RewriteRule,
pub before: PlanId,
pub after: PlanId,
pub detail: String,
}
impl CertifiedStep {
fn from_rewrite_step(step: &RewriteStep) -> Self {
Self {
rule: step.rule,
before: step.before,
after: step.after,
detail: step.detail.clone(),
}
}
}
#[derive(Debug, Clone)]
pub struct RewriteCertificate {
pub version: CertificateVersion,
pub policy: RewritePolicy,
pub before_hash: PlanHash,
pub after_hash: PlanHash,
pub before_node_count: usize,
pub after_node_count: usize,
pub steps: Vec<CertifiedStep>,
}
impl RewriteCertificate {
#[must_use]
pub fn is_identity(&self) -> bool {
self.steps.is_empty() && self.before_hash == self.after_hash
}
#[must_use]
pub fn fingerprint(&self) -> u64 {
let mut h = FNV_OFFSET;
h = fnv_u64(h, u64::from(self.version.number()));
let policy_bits: u8 = pack_policy(self.policy);
h = fnv_u8(h, policy_bits);
h = fnv_u64(h, self.before_hash.value());
h = fnv_u64(h, self.after_hash.value());
h = fnv_u64(h, self.steps.len() as u64);
for step in &self.steps {
h = fnv_u8(h, step.rule as u8);
h = fnv_u64(h, step.before.index() as u64);
h = fnv_u64(h, step.after.index() as u64);
}
h
}
#[must_use]
pub fn minimize(&self) -> Self {
let mut minimized: Vec<CertifiedStep> = Vec::with_capacity(self.steps.len());
for step in &self.steps {
if step.before == step.after {
continue;
}
let is_inverse = minimized.last().is_some_and(|prev| {
prev.rule == step.rule
&& is_self_inverse(step.rule)
&& prev.before == step.after
&& prev.after == step.before
});
if is_inverse {
minimized.pop();
continue;
}
let is_dup = minimized.last().is_some_and(|prev| {
prev.rule == step.rule && prev.before == step.before && prev.after == step.after
});
if is_dup {
continue;
}
minimized.push(step.clone());
}
Self {
version: self.version,
policy: self.policy,
before_hash: self.before_hash,
after_hash: self.after_hash,
before_node_count: self.before_node_count,
after_node_count: self.after_node_count,
steps: minimized,
}
}
pub fn compact(&self) -> Result<CompactCertificate, CompactCertificateError> {
let before_node_count = u32::try_from(self.before_node_count).map_err(|_| {
CompactCertificateError::NodeCountOverflow {
field: "before_node_count",
value: self.before_node_count,
}
})?;
let after_node_count = u32::try_from(self.after_node_count).map_err(|_| {
CompactCertificateError::NodeCountOverflow {
field: "after_node_count",
value: self.after_node_count,
}
})?;
let mut steps = Vec::with_capacity(self.steps.len());
for (idx, step) in self.steps.iter().enumerate() {
steps.push(CompactStep::try_from_certified(idx, step)?);
}
Ok(CompactCertificate {
version: self.version,
policy_bits: pack_policy(self.policy),
before_hash: self.before_hash,
after_hash: self.after_hash,
before_node_count,
after_node_count,
steps,
})
}
}
fn is_self_inverse(rule: RewriteRule) -> bool {
matches!(rule, RewriteRule::JoinCommute | RewriteRule::RaceCommute)
}
fn pack_policy(policy: RewritePolicy) -> u8 {
u8::from(policy.associativity)
| (u8::from(policy.commutativity) << 1)
| (u8::from(policy.distributivity) << 2)
| (u8::from(policy.require_binary_joins) << 3)
| (u8::from(policy.timeout_simplification) << 4)
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CompactStep {
pub rule: u8,
pub before: u32,
pub after: u32,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CompactCertificateError {
NodeCountOverflow {
field: &'static str,
value: usize,
},
StepNodeOverflow {
step: usize,
field: &'static str,
value: usize,
},
}
impl CompactStep {
fn try_from_certified(
step_index: usize,
step: &CertifiedStep,
) -> Result<Self, CompactCertificateError> {
let before = u32::try_from(step.before.index()).map_err(|_| {
CompactCertificateError::StepNodeOverflow {
step: step_index,
field: "before",
value: step.before.index(),
}
})?;
let after = u32::try_from(step.after.index()).map_err(|_| {
CompactCertificateError::StepNodeOverflow {
step: step_index,
field: "after",
value: step.after.index(),
}
})?;
Ok(Self {
rule: step.rule as u8,
before,
after,
})
}
pub const WIRE_SIZE: usize = 9;
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompactCertificate {
pub version: CertificateVersion,
pub policy_bits: u8,
pub before_hash: PlanHash,
pub after_hash: PlanHash,
pub before_node_count: u32,
pub after_node_count: u32,
pub steps: Vec<CompactStep>,
}
impl CompactCertificate {
pub const HEADER_SIZE: usize = 33;
#[must_use]
pub fn byte_size_bound(&self) -> usize {
Self::HEADER_SIZE.saturating_add(self.steps.len().saturating_mul(CompactStep::WIRE_SIZE))
}
#[must_use]
pub fn is_within_linear_bound(&self) -> bool {
let node_bound = self.after_node_count.max(self.before_node_count) as usize;
self.steps.len() <= node_bound
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DagCostSnapshot {
pub node_count: usize,
pub joins: usize,
pub races: usize,
pub timeouts: usize,
pub depth: usize,
}
impl DagCostSnapshot {
#[must_use]
pub fn of(dag: &PlanDag) -> Self {
let mut joins = 0;
let mut races = 0;
let mut timeouts = 0;
for node in &dag.nodes {
match node {
PlanNode::Join { .. } => joins += 1,
PlanNode::Race { .. } => races += 1,
PlanNode::Timeout { .. } => timeouts += 1,
PlanNode::Leaf { .. } => {}
}
}
Self {
node_count: dag.nodes.len(),
joins,
races,
timeouts,
depth: dag_depth(dag),
}
}
}
fn dag_depth(dag: &PlanDag) -> usize {
fn depth_of(dag: &PlanDag, id: PlanId, memo: &mut Vec<Option<usize>>) -> usize {
let idx = id.index();
if idx >= memo.len() {
return 0;
}
if let Some(d) = memo[idx] {
return d;
}
let d = match dag.node(id) {
Some(PlanNode::Leaf { .. }) => 1,
Some(PlanNode::Join { children } | PlanNode::Race { children }) => {
let max_child = children
.iter()
.map(|c| depth_of(dag, *c, memo))
.max()
.unwrap_or(0);
max_child + 1
}
Some(PlanNode::Timeout { child, .. }) => depth_of(dag, *child, memo) + 1,
None => 0,
};
memo[idx] = Some(d);
d
}
if dag.nodes.is_empty() {
return 0;
}
let mut memo = vec![None; dag.nodes.len()];
dag.root.map_or(0, |root| depth_of(dag, root, &mut memo))
}
#[derive(Debug, Clone)]
pub struct ExplanationEntry {
pub step_index: usize,
pub law: &'static str,
pub description: String,
pub side_conditions: Vec<&'static str>,
}
#[derive(Debug, Clone)]
pub struct ExplanationLedger {
pub before: DagCostSnapshot,
pub after: DagCostSnapshot,
pub entries: Vec<ExplanationEntry>,
}
impl ExplanationLedger {
#[must_use]
pub fn render(&self) -> String {
use std::fmt::Write;
let mut out = String::new();
out.push_str("=== Plan Rewrite Explanation ===\n");
let _ = writeln!(
out,
"Before: {} nodes (J={} R={} T={}, depth={})",
self.before.node_count,
self.before.joins,
self.before.races,
self.before.timeouts,
self.before.depth,
);
let _ = writeln!(
out,
"After: {} nodes (J={} R={} T={}, depth={})",
self.after.node_count,
self.after.joins,
self.after.races,
self.after.timeouts,
self.after.depth,
);
let node_delta = self.after.node_count.cast_signed() - self.before.node_count.cast_signed();
let depth_delta = self.after.depth.cast_signed() - self.before.depth.cast_signed();
let _ = writeln!(out, "Delta: nodes={node_delta:+}, depth={depth_delta:+}");
let _ = writeln!(out, "Steps: {}", self.entries.len());
for entry in &self.entries {
let _ = writeln!(
out,
"\n [{}] {}: {}",
entry.step_index, entry.law, entry.description,
);
for cond in &entry.side_conditions {
let _ = writeln!(out, " condition: {cond}");
}
}
out
}
}
impl RewriteCertificate {
#[must_use]
pub fn explain(&self, before_dag: &PlanDag, after_dag: &PlanDag) -> ExplanationLedger {
let before = DagCostSnapshot::of(before_dag);
let after = DagCostSnapshot::of(after_dag);
let entries = self
.steps
.iter()
.enumerate()
.map(|(i, step)| explain_step(i, step, self.policy, after_dag))
.collect();
ExplanationLedger {
before,
after,
entries,
}
}
}
fn rule_law_name(rule: RewriteRule) -> &'static str {
match rule {
RewriteRule::JoinAssoc => "Join Associativity",
RewriteRule::RaceAssoc => "Race Associativity",
RewriteRule::JoinCommute => "Join Commutativity",
RewriteRule::RaceCommute => "Race Commutativity",
RewriteRule::TimeoutMin => "Timeout Minimization",
RewriteRule::DedupRaceJoin => "Race-Join Deduplication",
}
}
fn rule_side_conditions(rule: RewriteRule, policy: RewritePolicy) -> Vec<&'static str> {
match rule {
RewriteRule::JoinAssoc | RewriteRule::RaceAssoc => {
if policy.associativity {
vec!["associativity enabled"]
} else {
vec![]
}
}
RewriteRule::JoinCommute | RewriteRule::RaceCommute => {
let mut conds = Vec::new();
if policy.commutativity {
conds.push("commutativity enabled");
}
conds.push("children are pairwise independent");
conds
}
RewriteRule::TimeoutMin => vec!["nested timeout structure"],
RewriteRule::DedupRaceJoin => {
let mut conds = vec!["shared child across race branches"];
if policy.distributivity {
conds.push("distributivity enabled");
}
if policy.require_binary_joins {
conds.push("binary joins required (conservative)");
}
conds
}
}
}
fn describe_node_brief(dag: &PlanDag, id: PlanId) -> String {
match dag.node(id) {
Some(PlanNode::Leaf { label }) => format!("Leaf({label})"),
Some(PlanNode::Join { children }) => format!("Join[{}]", children.len()),
Some(PlanNode::Race { children }) => format!("Race[{}]", children.len()),
Some(PlanNode::Timeout { duration, .. }) => format!("Timeout({duration:?})"),
None => format!("?{}", id.index()),
}
}
fn explain_step(
idx: usize,
step: &CertifiedStep,
policy: RewritePolicy,
dag: &PlanDag,
) -> ExplanationEntry {
let before_desc = describe_node_brief(dag, step.before);
let after_desc = describe_node_brief(dag, step.after);
let description = format!(
"node {} ({}) -> node {} ({})",
step.before.index(),
before_desc,
step.after.index(),
after_desc,
);
ExplanationEntry {
step_index: idx,
law: rule_law_name(step.rule),
description,
side_conditions: rule_side_conditions(step.rule, policy),
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum VerifyError {
VersionMismatch {
expected: u32,
found: u32,
},
HashMismatch {
expected: u64,
actual: u64,
},
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StepVerifyError {
MissingBeforeNode {
step: usize,
node: PlanId,
},
MissingAfterNode {
step: usize,
node: PlanId,
},
InvalidBeforeShape {
step: usize,
expected: &'static str,
},
InvalidAfterShape {
step: usize,
expected: &'static str,
},
SideConditionViolated {
step: usize,
condition: String,
},
}
pub fn verify_steps(cert: &RewriteCertificate, dag: &PlanDag) -> Result<(), StepVerifyError> {
for (idx, step) in cert.steps.iter().enumerate() {
verify_single_step(idx, step, cert.policy, dag)?;
}
Ok(())
}
fn verify_single_step(
idx: usize,
step: &CertifiedStep,
policy: RewritePolicy,
dag: &PlanDag,
) -> Result<(), StepVerifyError> {
match step.rule {
RewriteRule::JoinAssoc => verify_join_assoc_result(idx, step, policy, dag),
RewriteRule::RaceAssoc => verify_race_assoc_result(idx, step, policy, dag),
RewriteRule::JoinCommute => verify_join_commute_result(idx, step, policy, dag),
RewriteRule::RaceCommute => verify_race_commute_result(idx, step, policy, dag),
RewriteRule::TimeoutMin => verify_timeout_min_result(idx, step, policy, dag),
RewriteRule::DedupRaceJoin => verify_dedup_race_join_result(idx, step, policy, dag),
}
}
fn verify_side_conditions(
idx: usize,
step: &CertifiedStep,
policy: RewritePolicy,
dag: &PlanDag,
) -> Result<(), StepVerifyError> {
let checker = SideConditionChecker::new(dag);
if let Err(condition) =
check_side_conditions(step.rule, policy, &checker, dag, step.before, step.after)
{
return Err(StepVerifyError::SideConditionViolated {
step: idx,
condition,
});
}
Ok(())
}
fn verify_join_assoc_result(
idx: usize,
step: &CertifiedStep,
policy: RewritePolicy,
dag: &PlanDag,
) -> Result<(), StepVerifyError> {
let before = dag
.node(step.before)
.ok_or(StepVerifyError::MissingBeforeNode {
step: idx,
node: step.before,
})?;
let PlanNode::Join { children } = before else {
return Err(StepVerifyError::InvalidBeforeShape {
step: idx,
expected: "Join with at least one nested Join child",
});
};
let mut expected = Vec::new();
let mut changed = false;
for child in children {
match dag.node(*child) {
Some(PlanNode::Join { children }) => {
expected.extend(children.iter().copied());
changed = true;
}
Some(_) => expected.push(*child),
None => {
return Err(StepVerifyError::InvalidBeforeShape {
step: idx,
expected: "Join children must exist",
});
}
}
}
if !changed {
return Err(StepVerifyError::InvalidBeforeShape {
step: idx,
expected: "Join with at least one nested Join child",
});
}
let after = dag
.node(step.after)
.ok_or(StepVerifyError::MissingAfterNode {
step: idx,
node: step.after,
})?;
let PlanNode::Join {
children: after_children,
} = after
else {
return Err(StepVerifyError::InvalidAfterShape {
step: idx,
expected: "Join after JoinAssoc",
});
};
if *after_children != expected {
return Err(StepVerifyError::InvalidAfterShape {
step: idx,
expected: "Flattened Join children",
});
}
verify_side_conditions(idx, step, policy, dag)
}
fn verify_race_assoc_result(
idx: usize,
step: &CertifiedStep,
policy: RewritePolicy,
dag: &PlanDag,
) -> Result<(), StepVerifyError> {
let before = dag
.node(step.before)
.ok_or(StepVerifyError::MissingBeforeNode {
step: idx,
node: step.before,
})?;
let PlanNode::Race { children } = before else {
return Err(StepVerifyError::InvalidBeforeShape {
step: idx,
expected: "Race with at least one nested Race child",
});
};
let mut expected = Vec::new();
let mut changed = false;
for child in children {
match dag.node(*child) {
Some(PlanNode::Race { children }) => {
expected.extend(children.iter().copied());
changed = true;
}
Some(_) => expected.push(*child),
None => {
return Err(StepVerifyError::InvalidBeforeShape {
step: idx,
expected: "Race children must exist",
});
}
}
}
if !changed {
return Err(StepVerifyError::InvalidBeforeShape {
step: idx,
expected: "Race with at least one nested Race child",
});
}
let after = dag
.node(step.after)
.ok_or(StepVerifyError::MissingAfterNode {
step: idx,
node: step.after,
})?;
let PlanNode::Race {
children: after_children,
} = after
else {
return Err(StepVerifyError::InvalidAfterShape {
step: idx,
expected: "Race after RaceAssoc",
});
};
if *after_children != expected {
return Err(StepVerifyError::InvalidAfterShape {
step: idx,
expected: "Flattened Race children",
});
}
verify_side_conditions(idx, step, policy, dag)
}
fn verify_join_commute_result(
idx: usize,
step: &CertifiedStep,
policy: RewritePolicy,
dag: &PlanDag,
) -> Result<(), StepVerifyError> {
let before = dag
.node(step.before)
.ok_or(StepVerifyError::MissingBeforeNode {
step: idx,
node: step.before,
})?;
if !matches!(before, PlanNode::Join { .. }) {
return Err(StepVerifyError::InvalidBeforeShape {
step: idx,
expected: "Join before JoinCommute",
});
}
let after = dag
.node(step.after)
.ok_or(StepVerifyError::MissingAfterNode {
step: idx,
node: step.after,
})?;
if !matches!(after, PlanNode::Join { .. }) {
return Err(StepVerifyError::InvalidAfterShape {
step: idx,
expected: "Join after JoinCommute",
});
}
verify_side_conditions(idx, step, policy, dag)
}
fn verify_race_commute_result(
idx: usize,
step: &CertifiedStep,
policy: RewritePolicy,
dag: &PlanDag,
) -> Result<(), StepVerifyError> {
let before = dag
.node(step.before)
.ok_or(StepVerifyError::MissingBeforeNode {
step: idx,
node: step.before,
})?;
if !matches!(before, PlanNode::Race { .. }) {
return Err(StepVerifyError::InvalidBeforeShape {
step: idx,
expected: "Race before RaceCommute",
});
}
let after = dag
.node(step.after)
.ok_or(StepVerifyError::MissingAfterNode {
step: idx,
node: step.after,
})?;
if !matches!(after, PlanNode::Race { .. }) {
return Err(StepVerifyError::InvalidAfterShape {
step: idx,
expected: "Race after RaceCommute",
});
}
verify_side_conditions(idx, step, policy, dag)
}
fn verify_timeout_min_result(
idx: usize,
step: &CertifiedStep,
policy: RewritePolicy,
dag: &PlanDag,
) -> Result<(), StepVerifyError> {
let before = dag
.node(step.before)
.ok_or(StepVerifyError::MissingBeforeNode {
step: idx,
node: step.before,
})?;
let PlanNode::Timeout { child, duration } = before else {
return Err(StepVerifyError::InvalidBeforeShape {
step: idx,
expected: "Timeout wrapping a Timeout child",
});
};
let PlanNode::Timeout {
child: inner_child,
duration: inner_duration,
} = dag
.node(*child)
.ok_or(StepVerifyError::InvalidBeforeShape {
step: idx,
expected: "Timeout wrapping a Timeout child",
})?
else {
return Err(StepVerifyError::InvalidBeforeShape {
step: idx,
expected: "Timeout wrapping a Timeout child",
});
};
let expected_duration = if *duration <= *inner_duration {
*duration
} else {
*inner_duration
};
let after = dag
.node(step.after)
.ok_or(StepVerifyError::MissingAfterNode {
step: idx,
node: step.after,
})?;
let PlanNode::Timeout {
child: after_child,
duration: after_duration,
} = after
else {
return Err(StepVerifyError::InvalidAfterShape {
step: idx,
expected: "Timeout after TimeoutMin",
});
};
if *after_child != *inner_child || *after_duration != expected_duration {
return Err(StepVerifyError::InvalidAfterShape {
step: idx,
expected: "Timeout with min(d1,d2) over inner child",
});
}
verify_side_conditions(idx, step, policy, dag)
}
#[allow(clippy::too_many_lines)]
fn verify_dedup_race_join_result(
idx: usize,
step: &CertifiedStep,
policy: RewritePolicy,
dag: &PlanDag,
) -> Result<(), StepVerifyError> {
let after_node = dag
.node(step.after)
.ok_or(StepVerifyError::MissingAfterNode {
step: idx,
node: step.after,
})?;
let PlanNode::Join {
children: after_children,
} = after_node
else {
return Err(StepVerifyError::InvalidAfterShape {
step: idx,
expected: "Join node after DedupRaceJoin",
});
};
if after_children.len() != 2 {
return Err(StepVerifyError::InvalidAfterShape {
step: idx,
expected: "Join with exactly 2 children (shared + race)",
});
}
let before_node = dag
.node(step.before)
.ok_or(StepVerifyError::MissingBeforeNode {
step: idx,
node: step.before,
})?;
let PlanNode::Race { children } = before_node else {
return Err(StepVerifyError::InvalidBeforeShape {
step: idx,
expected: "Race of Join children before DedupRaceJoin",
});
};
if children.len() < 2 {
return Err(StepVerifyError::InvalidBeforeShape {
step: idx,
expected: "Race with >= 2 Join children before DedupRaceJoin",
});
}
let requires_binary_joins = policy.requires_binary_joins();
let allows_shared_non_leaf = policy.allows_shared_non_leaf();
if requires_binary_joins && children.len() != 2 {
return Err(StepVerifyError::InvalidBeforeShape {
step: idx,
expected: "Binary race required by Conservative policy",
});
}
let mut join_children = Vec::with_capacity(children.len());
for child in children {
match dag.node(*child) {
Some(PlanNode::Join { children }) => {
if requires_binary_joins && children.len() != 2 {
return Err(StepVerifyError::InvalidBeforeShape {
step: idx,
expected: "Binary joins required by Conservative policy",
});
}
join_children.push(children.clone());
}
_ => {
return Err(StepVerifyError::InvalidBeforeShape {
step: idx,
expected: "Race children must be Join nodes",
});
}
}
}
let mut intersection: std::collections::BTreeSet<PlanId> =
join_children[0].iter().copied().collect();
for join_nodes in join_children.iter().skip(1) {
let set: std::collections::BTreeSet<PlanId> = join_nodes.iter().copied().collect();
intersection.retain(|id| set.contains(id));
}
if intersection.len() != 1 {
return Err(StepVerifyError::InvalidBeforeShape {
step: idx,
expected: "Race joins must share exactly one child",
});
}
let shared = *intersection.iter().next().expect("shared");
if !allows_shared_non_leaf {
match dag.node(shared) {
Some(PlanNode::Leaf { .. }) => {}
_ => {
return Err(StepVerifyError::InvalidBeforeShape {
step: idx,
expected: "Shared child must be a Leaf under Conservative policy",
});
}
}
}
if !after_children.contains(&shared) {
return Err(StepVerifyError::InvalidAfterShape {
step: idx,
expected: "Join containing the shared child after DedupRaceJoin",
});
}
let has_race_child = after_children.iter().any(|child_id| {
dag.node(*child_id)
.is_some_and(|n| matches!(n, PlanNode::Race { .. }))
});
if !has_race_child {
return Err(StepVerifyError::InvalidAfterShape {
step: idx,
expected: "Join containing a Race child after DedupRaceJoin",
});
}
verify_side_conditions(idx, step, policy, dag)
}
pub fn verify(cert: &RewriteCertificate, dag: &PlanDag) -> Result<(), VerifyError> {
if cert.version != CertificateVersion::CURRENT {
return Err(VerifyError::VersionMismatch {
expected: CertificateVersion::CURRENT.number(),
found: cert.version.number(),
});
}
let actual = PlanHash::of(dag);
if cert.after_hash != actual {
return Err(VerifyError::HashMismatch {
expected: cert.after_hash.value(),
actual: actual.value(),
});
}
Ok(())
}
impl PlanDag {
pub fn apply_rewrites_certified(
&mut self,
policy: RewritePolicy,
rules: &[RewriteRule],
) -> (RewriteReport, RewriteCertificate) {
let before_hash = PlanHash::of(self);
let before_node_count = self.nodes.len();
let report = self.apply_rewrites(policy, rules);
let after_hash = PlanHash::of(self);
let after_node_count = self.nodes.len();
let steps = report
.steps()
.iter()
.map(CertifiedStep::from_rewrite_step)
.collect();
let cert = RewriteCertificate {
version: CertificateVersion::CURRENT,
policy,
before_hash,
after_hash,
before_node_count,
after_node_count,
steps,
};
(report, cert)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test_utils::init_test_logging;
use std::time::Duration;
fn init_test() {
init_test_logging();
}
fn dedup_race_join_ledger() -> ExplanationLedger {
let mut before_dag = PlanDag::new();
let shared = before_dag.leaf("shared");
let left = before_dag.leaf("left");
let right = before_dag.leaf("right");
let join_a = before_dag.join(vec![shared, left]);
let join_b = before_dag.join(vec![shared, right]);
let race = before_dag.race(vec![join_a, join_b]);
before_dag.set_root(race);
let mut after_dag = before_dag.clone();
let (_, cert) = after_dag
.apply_rewrites_certified(RewritePolicy::conservative(), &[RewriteRule::DedupRaceJoin]);
cert.explain(&before_dag, &after_dag)
}
#[test]
fn hash_deterministic_across_calls() {
init_test();
let mut dag = PlanDag::new();
let a = dag.leaf("a");
let b = dag.leaf("b");
let join = dag.join(vec![a, b]);
dag.set_root(join);
let h1 = PlanHash::of(&dag);
let h2 = PlanHash::of(&dag);
assert_eq!(h1, h2);
}
#[test]
fn different_dags_produce_different_hashes() {
init_test();
let mut dag1 = PlanDag::new();
let a = dag1.leaf("a");
let b = dag1.leaf("b");
let join = dag1.join(vec![a, b]);
dag1.set_root(join);
let mut dag2 = PlanDag::new();
let c = dag2.leaf("c");
let d = dag2.leaf("d");
let race = dag2.race(vec![c, d]);
dag2.set_root(race);
assert_ne!(PlanHash::of(&dag1), PlanHash::of(&dag2));
}
#[test]
fn child_order_matters() {
init_test();
let mut dag1 = PlanDag::new();
let a = dag1.leaf("a");
let b = dag1.leaf("b");
let join1 = dag1.join(vec![a, b]);
dag1.set_root(join1);
let mut dag2 = PlanDag::new();
let a2 = dag2.leaf("a");
let b2 = dag2.leaf("b");
let join2 = dag2.join(vec![b2, a2]);
dag2.set_root(join2);
assert_ne!(PlanHash::of(&dag1), PlanHash::of(&dag2));
}
#[test]
fn timeout_duration_affects_hash() {
init_test();
let mut dag1 = PlanDag::new();
let a = dag1.leaf("a");
let t1 = dag1.timeout(a, Duration::from_secs(1));
dag1.set_root(t1);
let mut dag2 = PlanDag::new();
let a2 = dag2.leaf("a");
let t2 = dag2.timeout(a2, Duration::from_secs(2));
dag2.set_root(t2);
assert_ne!(PlanHash::of(&dag1), PlanHash::of(&dag2));
}
#[test]
fn certified_rewrite_produces_valid_certificate() {
init_test();
let mut dag = PlanDag::new();
let shared = dag.leaf("shared");
let left = dag.leaf("left");
let right = dag.leaf("right");
let join_a = dag.join(vec![shared, left]);
let join_b = dag.join(vec![shared, right]);
let race = dag.race(vec![join_a, join_b]);
dag.set_root(race);
let (report, cert) = dag
.apply_rewrites_certified(RewritePolicy::conservative(), &[RewriteRule::DedupRaceJoin]);
assert_eq!(report.steps().len(), 1);
assert_eq!(cert.steps.len(), 1);
assert_eq!(cert.version, CertificateVersion::CURRENT);
assert_eq!(cert.policy, RewritePolicy::conservative());
assert_ne!(cert.before_hash, cert.after_hash);
assert!(!cert.is_identity());
assert!(verify(&cert, &dag).is_ok());
}
#[test]
fn identity_rewrite_produces_identity_certificate() {
init_test();
let mut dag = PlanDag::new();
let a = dag.leaf("a");
let b = dag.leaf("b");
let join = dag.join(vec![a, b]);
dag.set_root(join);
let (_report, cert) = dag
.apply_rewrites_certified(RewritePolicy::conservative(), &[RewriteRule::DedupRaceJoin]);
assert!(cert.is_identity());
assert!(verify(&cert, &dag).is_ok());
}
#[test]
fn verify_detects_hash_mismatch() {
init_test();
let mut dag = PlanDag::new();
let shared = dag.leaf("shared");
let left = dag.leaf("left");
let right = dag.leaf("right");
let join_a = dag.join(vec![shared, left]);
let join_b = dag.join(vec![shared, right]);
let race = dag.race(vec![join_a, join_b]);
dag.set_root(race);
let (_report, cert) = dag
.apply_rewrites_certified(RewritePolicy::conservative(), &[RewriteRule::DedupRaceJoin]);
dag.leaf("extra");
let result = verify(&cert, &dag);
assert!(result.is_err());
assert!(matches!(result, Err(VerifyError::HashMismatch { .. })));
}
#[test]
fn certificate_fingerprint_is_deterministic() {
init_test();
let mut dag = PlanDag::new();
let shared = dag.leaf("shared");
let left = dag.leaf("left");
let right = dag.leaf("right");
let join_a = dag.join(vec![shared, left]);
let join_b = dag.join(vec![shared, right]);
let race = dag.race(vec![join_a, join_b]);
dag.set_root(race);
let (_, cert) = dag
.apply_rewrites_certified(RewritePolicy::conservative(), &[RewriteRule::DedupRaceJoin]);
let fp1 = cert.fingerprint();
let fp2 = cert.fingerprint();
assert_eq!(fp1, fp2);
assert_ne!(fp1, 0);
}
#[test]
fn version_mismatch_detected() {
init_test();
let mut dag = PlanDag::new();
let a = dag.leaf("a");
dag.set_root(a);
let (_, mut cert) = dag
.apply_rewrites_certified(RewritePolicy::conservative(), &[RewriteRule::DedupRaceJoin]);
cert.version = CertificateVersion::from_number(99);
let result = verify(&cert, &dag);
assert!(matches!(result, Err(VerifyError::VersionMismatch { .. })));
}
#[test]
fn verify_steps_accepts_valid_rewrite() {
init_test();
let mut dag = PlanDag::new();
let shared = dag.leaf("shared");
let left = dag.leaf("left");
let right = dag.leaf("right");
let join_a = dag.join(vec![shared, left]);
let join_b = dag.join(vec![shared, right]);
let race = dag.race(vec![join_a, join_b]);
dag.set_root(race);
let (_, cert) = dag
.apply_rewrites_certified(RewritePolicy::conservative(), &[RewriteRule::DedupRaceJoin]);
assert!(verify_steps(&cert, &dag).is_ok());
}
#[test]
fn verify_steps_rejects_missing_after_node() {
init_test();
let mut dag = PlanDag::new();
let shared = dag.leaf("shared"); let left = dag.leaf("left"); let right = dag.leaf("right"); let join_a = dag.join(vec![shared, left]); let join_b = dag.join(vec![shared, right]); let race = dag.race(vec![join_a, join_b]); dag.set_root(race);
let cert = RewriteCertificate {
version: CertificateVersion::CURRENT,
policy: RewritePolicy::conservative(),
before_hash: PlanHash::of(&dag),
after_hash: PlanHash::of(&dag),
before_node_count: 6,
after_node_count: 6,
steps: vec![CertifiedStep {
rule: RewriteRule::DedupRaceJoin,
before: race,
after: PlanId::new(999), detail: "fake".to_string(),
}],
};
let result = verify_steps(&cert, &dag);
assert!(matches!(
result,
Err(StepVerifyError::MissingAfterNode { .. })
));
}
#[test]
fn verify_steps_rejects_wrong_after_shape() {
init_test();
let mut dag = PlanDag::new();
let shared = dag.leaf("shared"); let left = dag.leaf("left"); let right = dag.leaf("right"); let join_a = dag.join(vec![shared, left]); let join_b = dag.join(vec![shared, right]); let race = dag.race(vec![join_a, join_b]); dag.set_root(race);
let cert = RewriteCertificate {
version: CertificateVersion::CURRENT,
policy: RewritePolicy::conservative(),
before_hash: PlanHash::of(&dag),
after_hash: PlanHash::of(&dag),
before_node_count: 6,
after_node_count: 6,
steps: vec![CertifiedStep {
rule: RewriteRule::DedupRaceJoin,
before: race,
after: shared, detail: "fake".to_string(),
}],
};
let result = verify_steps(&cert, &dag);
assert!(matches!(
result,
Err(StepVerifyError::InvalidAfterShape { .. })
));
}
#[test]
fn verify_steps_identity_passes() {
init_test();
let mut dag = PlanDag::new();
let a = dag.leaf("a");
let b = dag.leaf("b");
let join = dag.join(vec![a, b]);
dag.set_root(join);
let (_, cert) = dag
.apply_rewrites_certified(RewritePolicy::conservative(), &[RewriteRule::DedupRaceJoin]);
assert!(cert.is_identity());
assert!(verify_steps(&cert, &dag).is_ok());
}
#[test]
fn minimize_removes_noop_steps() {
init_test();
let mut dag = PlanDag::new();
let a = dag.leaf("a");
dag.set_root(a);
let (_, base_cert) = dag
.apply_rewrites_certified(RewritePolicy::conservative(), &[RewriteRule::DedupRaceJoin]);
let mut cert = base_cert;
cert.steps.push(CertifiedStep {
rule: RewriteRule::JoinAssoc,
before: a,
after: a,
detail: "no-op".to_string(),
});
assert_eq!(cert.steps.len(), 1);
let minimized = cert.minimize();
assert!(minimized.steps.is_empty());
}
#[test]
fn minimize_removes_inverse_commute_pair() {
init_test();
let mut dag = PlanDag::new();
let a = dag.leaf("a");
let b = dag.leaf("b");
let join = dag.join(vec![a, b]);
dag.set_root(join);
let hash = PlanHash::of(&dag);
let cert = RewriteCertificate {
version: CertificateVersion::CURRENT,
policy: RewritePolicy::conservative(),
before_hash: hash,
after_hash: hash,
before_node_count: 3,
after_node_count: 3,
steps: vec![
CertifiedStep {
rule: RewriteRule::JoinCommute,
before: PlanId::new(2),
after: PlanId::new(3),
detail: "commute forward".to_string(),
},
CertifiedStep {
rule: RewriteRule::JoinCommute,
before: PlanId::new(3),
after: PlanId::new(2),
detail: "commute back".to_string(),
},
],
};
let minimized = cert.minimize();
assert!(minimized.steps.is_empty());
}
#[test]
fn minimize_removes_consecutive_duplicates() {
init_test();
let hash = PlanHash(0x1234);
let cert = RewriteCertificate {
version: CertificateVersion::CURRENT,
policy: RewritePolicy::conservative(),
before_hash: hash,
after_hash: hash,
before_node_count: 4,
after_node_count: 4,
steps: vec![
CertifiedStep {
rule: RewriteRule::JoinAssoc,
before: PlanId::new(0),
after: PlanId::new(1),
detail: "assoc".to_string(),
},
CertifiedStep {
rule: RewriteRule::JoinAssoc,
before: PlanId::new(0),
after: PlanId::new(1),
detail: "assoc dup".to_string(),
},
],
};
let minimized = cert.minimize();
assert_eq!(minimized.steps.len(), 1);
}
#[test]
fn minimize_preserves_non_redundant_steps() {
init_test();
let mut dag = PlanDag::new();
let shared = dag.leaf("shared");
let left = dag.leaf("left");
let right = dag.leaf("right");
let join_a = dag.join(vec![shared, left]);
let join_b = dag.join(vec![shared, right]);
let race = dag.race(vec![join_a, join_b]);
dag.set_root(race);
let (_, cert) = dag
.apply_rewrites_certified(RewritePolicy::conservative(), &[RewriteRule::DedupRaceJoin]);
let minimized = cert.minimize();
assert_eq!(minimized.steps.len(), cert.steps.len());
assert_eq!(minimized.before_hash, cert.before_hash);
assert_eq!(minimized.after_hash, cert.after_hash);
}
#[test]
fn compact_certificate_strips_details() {
init_test();
let mut dag = PlanDag::new();
let shared = dag.leaf("shared");
let left = dag.leaf("left");
let right = dag.leaf("right");
let join_a = dag.join(vec![shared, left]);
let join_b = dag.join(vec![shared, right]);
let race = dag.race(vec![join_a, join_b]);
dag.set_root(race);
let (_, cert) = dag
.apply_rewrites_certified(RewritePolicy::conservative(), &[RewriteRule::DedupRaceJoin]);
let compact = cert
.compact()
.expect("compact certificate fits u32 wire format");
assert_eq!(compact.steps.len(), cert.steps.len());
assert_eq!(compact.version, cert.version);
assert_eq!(compact.before_hash, cert.before_hash);
assert_eq!(compact.after_hash, cert.after_hash);
for (cs, fs) in compact.steps.iter().zip(cert.steps.iter()) {
assert_eq!(cs.rule, fs.rule as u8);
assert_eq!(cs.before, u32::try_from(fs.before.index()).unwrap());
assert_eq!(cs.after, u32::try_from(fs.after.index()).unwrap());
}
}
#[test]
fn compact_byte_size_bound_is_tight() {
init_test();
let mut dag = PlanDag::new();
let shared = dag.leaf("shared");
let left = dag.leaf("left");
let right = dag.leaf("right");
let join_a = dag.join(vec![shared, left]);
let join_b = dag.join(vec![shared, right]);
let race = dag.race(vec![join_a, join_b]);
dag.set_root(race);
let (_, cert) = dag
.apply_rewrites_certified(RewritePolicy::conservative(), &[RewriteRule::DedupRaceJoin]);
let compact = cert
.compact()
.expect("compact certificate fits u32 wire format");
let bound = compact.byte_size_bound();
assert_eq!(
bound,
CompactCertificate::HEADER_SIZE + CompactStep::WIRE_SIZE
);
assert_eq!(bound, 42);
}
#[test]
fn certificate_within_linear_bound() {
init_test();
let mut dag = PlanDag::new();
let shared = dag.leaf("shared");
let left = dag.leaf("left");
let right = dag.leaf("right");
let join_a = dag.join(vec![shared, left]);
let join_b = dag.join(vec![shared, right]);
let race = dag.race(vec![join_a, join_b]);
dag.set_root(race);
let (_, cert) = dag
.apply_rewrites_certified(RewritePolicy::conservative(), &[RewriteRule::DedupRaceJoin]);
let compact = cert
.compact()
.expect("compact certificate fits u32 wire format");
assert!(compact.is_within_linear_bound());
}
#[test]
fn identity_certificate_compact_is_minimal() {
init_test();
let mut dag = PlanDag::new();
let a = dag.leaf("a");
let b = dag.leaf("b");
let join = dag.join(vec![a, b]);
dag.set_root(join);
let (_, cert) = dag
.apply_rewrites_certified(RewritePolicy::conservative(), &[RewriteRule::DedupRaceJoin]);
assert!(cert.is_identity());
let compact = cert
.compact()
.expect("compact certificate fits u32 wire format");
assert!(compact.steps.is_empty());
assert_eq!(compact.byte_size_bound(), CompactCertificate::HEADER_SIZE);
assert!(compact.is_within_linear_bound());
}
#[test]
fn minimize_then_compact_reduces_size() {
init_test();
let hash = PlanHash(0xABCD);
let cert = RewriteCertificate {
version: CertificateVersion::CURRENT,
policy: RewritePolicy::conservative(),
before_hash: hash,
after_hash: hash,
before_node_count: 5,
after_node_count: 5,
steps: vec![
CertifiedStep {
rule: RewriteRule::RaceCommute,
before: PlanId::new(0),
after: PlanId::new(1),
detail: "commute".to_string(),
},
CertifiedStep {
rule: RewriteRule::RaceCommute,
before: PlanId::new(1),
after: PlanId::new(0),
detail: "un-commute".to_string(),
},
CertifiedStep {
rule: RewriteRule::JoinAssoc,
before: PlanId::new(2),
after: PlanId::new(3),
detail: "assoc".to_string(),
},
],
};
let raw_compact = cert
.compact()
.expect("compact certificate fits u32 wire format");
let minimized_compact = cert
.minimize()
.compact()
.expect("compact certificate fits u32 wire format");
assert_eq!(raw_compact.steps.len(), 3);
assert_eq!(minimized_compact.steps.len(), 1);
assert!(minimized_compact.byte_size_bound() < raw_compact.byte_size_bound());
}
#[test]
fn compact_rejects_node_count_overflow() {
init_test();
if usize::BITS <= u32::BITS {
return;
}
let cert = RewriteCertificate {
version: CertificateVersion::CURRENT,
policy: RewritePolicy::conservative(),
before_hash: PlanHash(0xABCD),
after_hash: PlanHash(0xABCD),
before_node_count: (u32::MAX as usize) + 1,
after_node_count: 1,
steps: Vec::new(),
};
let err = cert.compact().expect_err("overflow must be rejected");
assert_eq!(
err,
CompactCertificateError::NodeCountOverflow {
field: "before_node_count",
value: (u32::MAX as usize) + 1,
}
);
}
#[test]
fn compact_rejects_step_node_overflow() {
init_test();
if usize::BITS <= u32::BITS {
return;
}
let cert = RewriteCertificate {
version: CertificateVersion::CURRENT,
policy: RewritePolicy::conservative(),
before_hash: PlanHash(0x1234),
after_hash: PlanHash(0x1234),
before_node_count: 1,
after_node_count: 1,
steps: vec![CertifiedStep {
rule: RewriteRule::JoinAssoc,
before: PlanId::new((u32::MAX as usize) + 1),
after: PlanId::new(0),
detail: "overflow".to_string(),
}],
};
let err = cert.compact().expect_err("overflow must be rejected");
assert_eq!(
err,
CompactCertificateError::StepNodeOverflow {
step: 0,
field: "before",
value: (u32::MAX as usize) + 1,
}
);
}
#[test]
fn dag_cost_snapshot_counts_nodes() {
init_test();
let mut dag = PlanDag::new();
let a = dag.leaf("a");
let b = dag.leaf("b");
let join = dag.join(vec![a, b]);
let c = dag.leaf("c");
let race = dag.race(vec![join, c]);
dag.set_root(race);
let snap = DagCostSnapshot::of(&dag);
assert_eq!(snap.node_count, 5);
assert_eq!(snap.joins, 1);
assert_eq!(snap.races, 1);
assert_eq!(snap.timeouts, 0);
assert_eq!(snap.depth, 3); }
#[test]
fn dag_depth_handles_timeout() {
init_test();
let mut dag = PlanDag::new();
let a = dag.leaf("a");
let t = dag.timeout(a, Duration::from_secs(1));
dag.set_root(t);
let snap = DagCostSnapshot::of(&dag);
assert_eq!(snap.depth, 2); assert_eq!(snap.timeouts, 1);
}
#[test]
fn explain_produces_entries_for_each_step() {
init_test();
let mut before_dag = PlanDag::new();
let shared = before_dag.leaf("shared");
let left = before_dag.leaf("left");
let right = before_dag.leaf("right");
let join_a = before_dag.join(vec![shared, left]);
let join_b = before_dag.join(vec![shared, right]);
let race = before_dag.race(vec![join_a, join_b]);
before_dag.set_root(race);
let mut after_dag = before_dag.clone();
let (_, cert) = after_dag
.apply_rewrites_certified(RewritePolicy::conservative(), &[RewriteRule::DedupRaceJoin]);
let ledger = cert.explain(&before_dag, &after_dag);
assert_eq!(ledger.entries.len(), cert.steps.len());
assert_eq!(ledger.entries[0].law, "Race-Join Deduplication");
assert!(!ledger.entries[0].side_conditions.is_empty());
}
#[test]
fn explain_identity_is_empty() {
init_test();
let mut dag = PlanDag::new();
let a = dag.leaf("a");
let b = dag.leaf("b");
let join = dag.join(vec![a, b]);
dag.set_root(join);
let before_dag = dag.clone();
let (_, cert) = dag
.apply_rewrites_certified(RewritePolicy::conservative(), &[RewriteRule::DedupRaceJoin]);
let ledger = cert.explain(&before_dag, &dag);
assert!(ledger.entries.is_empty());
assert_eq!(ledger.before.node_count, ledger.after.node_count);
}
#[test]
fn explain_render_is_deterministic() {
init_test();
let ledger = dedup_race_join_ledger();
let r1 = ledger.render();
let r2 = ledger.render();
assert_eq!(r1, r2);
assert!(r1.contains("Plan Rewrite Explanation"));
assert!(r1.contains("Race-Join Deduplication"));
assert!(r1.contains("Before:"));
assert!(r1.contains("After:"));
assert!(r1.contains("Delta:"));
}
#[test]
fn explain_shows_cost_delta() {
init_test();
let ledger = dedup_race_join_ledger();
assert!(ledger.after.node_count >= ledger.before.node_count);
let rendered = ledger.render();
assert!(rendered.contains("nodes="));
assert!(rendered.contains("depth="));
}
#[test]
fn explain_render_snapshot_dedup_race_join() {
init_test();
let ledger = dedup_race_join_ledger();
insta::assert_snapshot!("plan_certificate_dedup_race_join_render", ledger.render());
}
#[test]
fn plan_hash_debug_clone_copy_eq_hash() {
use std::collections::HashSet;
let mut dag = PlanDag::new();
let a = dag.leaf("a");
dag.set_root(a);
let h = PlanHash::of(&dag);
let dbg = format!("{h:?}");
assert!(dbg.contains("PlanHash"));
let h2 = h;
assert_eq!(h, h2);
let h3 = h;
assert_eq!(h, h3);
let mut set = HashSet::new();
set.insert(h);
assert!(set.contains(&h));
}
#[test]
fn certificate_version_debug_clone_copy_eq() {
let v = CertificateVersion::CURRENT;
let dbg = format!("{v:?}");
assert!(dbg.contains("CertificateVersion"));
let v2 = v;
assert_eq!(v, v2);
let v3 = v;
assert_eq!(v, v3);
}
#[test]
fn compact_step_debug_clone_copy_eq() {
let s = CompactStep {
rule: 1,
before: 10,
after: 20,
};
let dbg = format!("{s:?}");
assert!(dbg.contains("CompactStep"));
let s2 = s;
assert_eq!(s, s2);
let s3 = s;
assert_eq!(s, s3);
}
}