use crate::ir::{BinOp, BufferAccess, BufferDecl, DataType, Expr, Node, Program};
use crate::optimizer::cost::CostCertificate;
use crate::optimizer::{registered_passes, ProgramPassKind};
/// One invariant violation discovered while auditing the registered optimizer passes.
///
/// Apart from [`PassInvariantFinding::RegistryError`], every variant records
/// which pass misbehaved (`pass`) and which synthetic-corpus program it was
/// run on (`program`).
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum PassInvariantFinding {
    /// The pass registry itself could not be produced, so no pass was audited.
    RegistryError {
        /// The registry error rendered via `to_string()`.
        detail: String,
    },
    /// A pass reported a change, but the rewritten program's cost certificate
    /// does not dominate-or-equal the input program's certificate.
    CostMonotoneViolation {
        /// Name of the offending pass, taken from its metadata.
        pass: &'static str,
        /// Name of the corpus program the pass was run on.
        program: &'static str,
        /// Comma-separated list of the cost dimensions that went up.
        increased: String,
    },
    /// A pass reported a change, but its output program has zero nodes.
    StructurallyInvalid {
        /// Name of the offending pass, taken from its metadata.
        pass: &'static str,
        /// Name of the corpus program the pass was run on.
        program: &'static str,
        /// Human-readable explanation of the structural problem.
        detail: String,
    },
}
/// Builds the fixed set of small named programs that every registered pass is audited against.
///
/// The corpus contains, in order:
/// * `"trivial"` — a single constant store on a 1x1x1 dispatch;
/// * `"arithmetic"` — a store of `3 + 4`, giving folding passes something to do;
/// * `"divergent"` — a 256-wide dispatch where only the invocation whose
///   `gid_x` equals 0 performs the store.
fn synthetic_corpus() -> Vec<(&'static str, Program)> {
    // All three programs share the same single-element `u32` read-write
    // output buffer named "out" at binding 0.
    let out_buffer = || {
        vec![BufferDecl::storage("out", 0, BufferAccess::ReadWrite, DataType::U32).with_count(1)]
    };

    let trivial = Program::wrapped(
        out_buffer(),
        [1, 1, 1],
        vec![Node::store("out", Expr::u32(0), Expr::u32(7))],
    );

    let sum = Expr::add(Expr::u32(3), Expr::u32(4));
    let arithmetic = Program::wrapped(
        out_buffer(),
        [1, 1, 1],
        vec![Node::store("out", Expr::u32(0), sum)],
    );

    let is_first_invocation = Expr::BinOp {
        op: BinOp::Eq,
        left: Box::new(Expr::gid_x()),
        right: Box::new(Expr::u32(0)),
    };
    let divergent = Program::wrapped(
        out_buffer(),
        [256, 1, 1],
        vec![Node::if_then(
            is_first_invocation,
            vec![Node::store("out", Expr::u32(0), Expr::u32(1))],
        )],
    );

    vec![
        ("trivial", trivial),
        ("arithmetic", arithmetic),
        ("divergent", divergent),
    ]
}
#[must_use]
pub fn audit_registered_passes() -> Vec<PassInvariantFinding> {
let passes = match registered_passes() {
Ok(passes) => passes,
Err(error) => {
return vec![PassInvariantFinding::RegistryError {
detail: error.to_string(),
}];
}
};
let corpus = synthetic_corpus();
let mut findings = Vec::new();
for pass in passes {
for (program_name, program) in &corpus {
findings.extend(audit_pass_on_program(&pass, program_name, program.clone()));
}
}
findings
}
/// Runs `pass` once over `program` and checks the pass invariants on the result.
///
/// Both invariants apply only when the pass reports that it changed the program:
///
/// * **Cost monotonicity** — the output's [`CostCertificate`] must satisfy
///   `dominates_or_equal` against the input's certificate. Otherwise a
///   [`PassInvariantFinding::CostMonotoneViolation`] is reported, listing the
///   dimensions that increased.
/// * **Structural sanity** — a non-empty input must not be rewritten into a
///   program with zero nodes. Otherwise a
///   [`PassInvariantFinding::StructurallyInvalid`] is reported.
///
/// A refusal from `try_transform` (an `Err`) is not counted as a violation, so
/// no findings are returned in that case.
fn audit_pass_on_program(
    pass: &ProgramPassKind,
    program_name: &'static str,
    program: Program,
) -> Vec<PassInvariantFinding> {
    let pre_cost = CostCertificate::for_program(&program);
    // Measured before `try_transform` takes ownership of `program`: the
    // structural check must know the input was actually non-empty, otherwise
    // an empty input would be misreported as "zero-node from non-empty input".
    let pre_node_count = program.stats().node_count;
    let pass_name = pass.metadata().name;
    let result = match pass.try_transform(program) {
        Ok(result) => result,
        // Refusing to rewrite is an allowed outcome, not an invariant violation.
        Err(_refusal) => return Vec::new(),
    };
    if !result.changed {
        // Both invariants only constrain programs the pass actually rewrote.
        return Vec::new();
    }

    let mut findings = Vec::new();
    let post_cost = CostCertificate::for_program(&result.program);
    if !post_cost.dominates_or_equal(&pre_cost) {
        let increased = post_cost.dimensions_increased_over(&pre_cost).join(",");
        findings.push(PassInvariantFinding::CostMonotoneViolation {
            pass: pass_name,
            program: program_name,
            increased,
        });
    }
    let post_node_count = result.program.stats().node_count;
    if pre_node_count != 0 && post_node_count == 0 {
        findings.push(PassInvariantFinding::StructurallyInvalid {
            pass: pass_name,
            program: program_name,
            detail: "rewrite produced zero-node program from non-empty input".into(),
        });
    }
    findings
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Passes allowed to increase cost (e.g. by trading one dimension for another).
    const COST_INCREASE_EXEMPT: &[&str] = &["autotune"];

    /// Fetches the named program from the synthetic corpus, panicking with
    /// `missing` if the corpus no longer contains it.
    fn corpus_program(name: &str, missing: &str) -> Program {
        synthetic_corpus()
            .into_iter()
            .find_map(|(candidate, program)| (candidate == name).then_some(program))
            .expect(missing)
    }

    #[test]
    fn synthetic_corpus_has_three_programs_with_distinct_shapes() {
        let names: Vec<&str> = synthetic_corpus()
            .iter()
            .map(|&(name, _)| name)
            .collect();
        assert_eq!(
            names.len(),
            3,
            "corpus contract: trivial, arithmetic, divergent"
        );
        assert!(names.contains(&"trivial"));
        assert!(names.contains(&"arithmetic"));
        assert!(names.contains(&"divergent"));
    }

    #[test]
    fn divergent_program_has_nonzero_divergence_score() {
        let divergent = corpus_program("divergent", "divergent program must be in corpus");
        let cost = CostCertificate::for_program(&divergent);
        assert!(
            cost.divergence_score >= 1,
            "the divergent program must register divergence — without this, the verifier \
             can't catch effect-lattice-related regressions"
        );
    }

    #[test]
    fn trivial_program_has_zero_divergence_score() {
        let trivial = corpus_program("trivial", "trivial must be in corpus");
        let cost = CostCertificate::for_program(&trivial);
        assert_eq!(cost.divergence_score, 0);
    }

    #[test]
    fn audit_runs_to_completion_without_panic() {
        let _findings = audit_registered_passes();
    }

    #[test]
    fn audit_finds_zero_cost_monotone_violations_on_built_ins() {
        let findings = audit_registered_passes();
        let cost_violations: Vec<_> = findings
            .iter()
            .filter(|finding| {
                matches!(
                    finding,
                    PassInvariantFinding::CostMonotoneViolation { pass, .. }
                        if !COST_INCREASE_EXEMPT.contains(pass)
                )
            })
            .collect();
        assert!(
            cost_violations.is_empty(),
            "built-in passes must be cost-monotone-down on the synthetic corpus; \
             non-exempt violations: {cost_violations:#?}"
        );
    }

    #[test]
    fn audit_finds_zero_structurally_invalid_outputs_on_built_ins() {
        let findings = audit_registered_passes();
        let invalid: Vec<_> = findings
            .iter()
            .filter(|finding| matches!(finding, PassInvariantFinding::StructurallyInvalid { .. }))
            .collect();
        assert!(
            invalid.is_empty(),
            "built-in passes must produce structurally-valid Programs; bad outputs: {invalid:#?}"
        );
    }
}