use crate::ir_inner::model::program::Program;
use rustc_hash::FxHashSet;
use std::sync::{Arc, LazyLock};
pub mod cost;
pub mod ctx;
pub mod diff_compile;
pub mod effect_lattice;
pub mod fact_substrate;
pub mod fusion_cert;
pub mod program_shape_facts;
pub mod shape_facts;
pub mod dsl;
pub mod eqsat;
pub mod eqsat_gpu;
pub mod eqsat_toml;
pub mod expr_arena;
pub mod hot_path_hints;
pub mod megakernel;
pub mod pass_invariants;
pub mod passes;
pub mod pre_lowering;
pub mod program_soa;
mod rewrite;
pub mod rewrite_proof;
pub mod rewrite_proof_registry;
mod scheduler;
#[cfg(test)]
mod tests;
pub use ctx::{scheduling_error_to_diagnostic, AdapterCaps, AnalysisCache, PassCtx};
pub use fusion_cert::FusionCertificate;
pub use scheduler::{
schedule_passes, OptimizerRunReport, PassRunMetric, PassScheduler, PassSchedulingError,
};
pub use vyre_macros::vyre_pass;
/// Static description of an optimizer pass.
///
/// `name` doubles as the default [`ProgramPass::pass_id`]; `requires` is checked
/// against the set of names already provided before the pass is allowed to run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PassMetadata {
/// Pass name (the default `pass_id`, and the name reported in optimizer errors).
pub name: &'static str,
/// Names that must already be available before this pass may run; a missing
/// entry surfaces as `OptimizerError::UnsatisfiedRequirement`.
pub requires: &'static [&'static str],
/// Names this pass invalidates. NOTE(review): consumed by the scheduler, which
/// is not visible in this file — confirm the exact semantics there.
pub invalidates: &'static [&'static str],
}
/// Verdict returned by [`ProgramPass::analyze`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PassAnalysis {
    /// Whether the pass's transform should be invoked for the analyzed program.
    pub should_run: bool,
}

impl PassAnalysis {
    /// Analysis verdict: run the pass.
    pub const RUN: PassAnalysis = PassAnalysis { should_run: true };
    /// Analysis verdict: skip the pass.
    pub const SKIP: PassAnalysis = PassAnalysis { should_run: false };
}
/// Output of a pass transform: the resulting program plus a change flag.
#[derive(Debug, Clone, PartialEq)]
pub struct PassResult {
/// Program produced by the pass.
pub program: Program,
/// `true` when the pass reports that `program` differs from its input
/// (`PassResult::from_programs` derives this via `PartialEq`).
pub changed: bool,
}
/// Structured reason a pass gives for declining to rewrite a program.
///
/// Returned from [`ProgramPass::try_transform`]. All payloads are `Copy`
/// (`i64` / `&'static str`), so values are cheap to clone and compare.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum RefusalReason {
    /// The rewrite would raise the program's cost.
    CostIncrease {
        /// Cost delta attributed to the rewrite.
        delta: i64,
        /// Human-readable explanation.
        detail: &'static str,
    },
    /// Composing `producer` with `consumer` would violate the effect lattice.
    EffectLatticeViolation {
        producer: &'static str,
        consumer: &'static str,
        /// Remediation hint, rendered after `fix=` by `Display`.
        suggested_fix: &'static str,
    },
    /// The rewrite would break a wire-format contract.
    WireContractViolation {
        detail: &'static str,
    },
    /// Any refusal without a dedicated variant.
    Other {
        detail: &'static str,
    },
}

impl RefusalReason {
    /// Stable, machine-readable tag for this refusal category.
    ///
    /// The same tag is the prefix of the `Display` rendering.
    #[must_use]
    pub fn kind(&self) -> &'static str {
        match *self {
            RefusalReason::CostIncrease { .. } => "cost_increase",
            RefusalReason::EffectLatticeViolation { .. } => "effect_lattice_violation",
            RefusalReason::WireContractViolation { .. } => "wire_contract_violation",
            RefusalReason::Other { .. } => "other",
        }
    }
}

impl std::fmt::Display for RefusalReason {
    /// Renders `<kind>: <payload>` with payload fields as `key=value` pairs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every payload field is `Copy`, so the bindings below copy by value.
        match *self {
            RefusalReason::CostIncrease { delta, detail } => {
                write!(f, "cost_increase: delta={delta} reason={detail}")
            }
            RefusalReason::EffectLatticeViolation {
                producer,
                consumer,
                suggested_fix,
            } => write!(
                f,
                "effect_lattice_violation: producer={producer} consumer={consumer} fix={suggested_fix}"
            ),
            RefusalReason::WireContractViolation { detail } => {
                write!(f, "wire_contract_violation: {detail}")
            }
            RefusalReason::Other { detail } => write!(f, "other: {detail}"),
        }
    }
}
impl PassResult {
#[must_use]
#[inline]
pub fn from_programs(before: &Program, program: Program) -> Self {
let changed = before != &program;
Self { program, changed }
}
#[must_use]
#[inline]
pub fn unchanged(program: Program) -> Self {
Self {
program,
changed: false,
}
}
}
/// Registration record for a built-in pass, gathered through `inventory`.
///
/// The full set of registrations is passed to `schedule_passes` by
/// `registered_pass_registrations`.
#[derive(Debug)]
pub struct ProgramPassRegistration {
/// Metadata for the registered pass.
pub metadata: PassMetadata,
/// Constructor invoked each time `registered_passes` builds a fresh pass list.
pub factory: fn() -> Box<dyn ProgramPass>,
}
// Declares the `inventory` collection that `registered_pass_registrations` iterates.
inventory::collect!(ProgramPassRegistration);
// Sealing module: `Sealed` is only nameable inside this crate, so downstream
// crates cannot implement `ProgramPass`.
pub(crate) mod private {
pub trait Sealed {}
}
/// A whole-program optimization pass.
///
/// Sealed via `private::Sealed`; implementors must also be `Send + Sync`.
pub trait ProgramPass: private::Sealed + Send + Sync {
/// Metadata (name, requirements, invalidations) describing this pass.
fn metadata(&self) -> PassMetadata;
/// Names this pass declares as preserved; empty by default.
/// NOTE(review): the module's tests assert every built-in leaves this empty
/// because the scheduler does not yet honor it.
fn preserves(&self) -> &'static [&'static str] {
&[]
}
/// Identifier for this pass; defaults to `metadata().name`.
fn pass_id(&self) -> &'static str {
self.metadata().name
}
/// Pre-check whose `should_run` verdict says whether `transform` applies.
fn analyze(&self, program: &Program) -> PassAnalysis;
/// Rewrites `program` (taken by value) and reports whether it changed.
fn transform(&self, program: Program) -> PassResult;
/// Fallible variant of `transform` that may refuse with a [`RefusalReason`].
/// The default never refuses and simply wraps `transform`.
fn try_transform(&self, program: Program) -> Result<PassResult, RefusalReason> {
Ok(self.transform(program))
}
/// Pass-specific 64-bit fingerprint of `program`.
/// NOTE(review): its exact use (e.g. memoizing pass results) is decided by
/// callers not visible here — confirm before relying on it.
fn fingerprint(&self, program: &Program) -> u64;
}
/// Owned, type-erased wrapper around a boxed [`ProgramPass`].
///
/// Every accessor forwards to the wrapped pass, so trait-method overrides
/// (`pass_id`, `preserves`, `try_transform`) are the ones that run.
pub struct ProgramPassKind(Box<dyn ProgramPass>);

impl ProgramPassKind {
    /// Boxes `pass` and wraps it.
    #[must_use]
    #[inline]
    pub fn new<P: ProgramPass + 'static>(pass: P) -> Self {
        Self::from_boxed(Box::new(pass))
    }

    /// Wraps an already-boxed pass, such as one produced by a registration factory.
    #[must_use]
    #[inline]
    pub fn from_boxed(pass: Box<dyn ProgramPass>) -> Self {
        ProgramPassKind(pass)
    }

    /// Borrows the wrapped pass as a trait object.
    #[inline]
    fn inner(&self) -> &dyn ProgramPass {
        &*self.0
    }

    /// Forwards to [`ProgramPass::metadata`].
    #[must_use]
    #[inline]
    pub fn metadata(&self) -> PassMetadata {
        self.inner().metadata()
    }

    /// Forwards to [`ProgramPass::pass_id`] (which may be overridden).
    #[must_use]
    #[inline]
    pub fn pass_id(&self) -> &'static str {
        self.inner().pass_id()
    }

    /// Forwards to [`ProgramPass::analyze`].
    #[must_use]
    #[inline]
    pub fn analyze(&self, program: &Program) -> PassAnalysis {
        self.inner().analyze(program)
    }

    /// Forwards to [`ProgramPass::transform`].
    #[must_use]
    #[inline]
    pub fn transform(&self, program: Program) -> PassResult {
        self.inner().transform(program)
    }

    /// Forwards to [`ProgramPass::try_transform`].
    #[inline]
    pub fn try_transform(&self, program: Program) -> Result<PassResult, RefusalReason> {
        self.inner().try_transform(program)
    }

    /// Forwards to [`ProgramPass::preserves`].
    #[must_use]
    #[inline]
    pub fn preserves(&self) -> &'static [&'static str] {
        self.inner().preserves()
    }
}
/// Errors surfaced by the optimizer driver.
///
/// The `MaxIterations` and `UnsatisfiedRequirement` messages end with a
/// `Fix:` hint for the developer.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
#[non_exhaustive]
pub enum OptimizerError {
/// The optimizer loop hit `max_iterations` without reaching a fixpoint;
/// `last_pass` is the pass named in the message as the one to inspect.
#[error(
"optimizer did not reach a fixpoint after {max_iterations} iterations. Fix: inspect pass `{last_pass}` for oscillating rewrites or raise the cap only with a convergence certificate."
)]
MaxIterations {
max_iterations: usize,
last_pass: &'static str,
},
/// `pass` lists `missing` in its `requires`, but no earlier pass provides it.
#[error(
"optimizer pass `{pass}` requires `{missing}` but no prior pass provides it. Fix: register the required analysis pass or remove the stale requirement."
)]
UnsatisfiedRequirement {
pass: &'static str,
missing: &'static str,
},
/// Pass-scheduling failure. `#[from]` enables `?` conversion and also makes
/// the inner error this error's `source()`.
/// NOTE(review): because Display repeats the inner message (`{0}`), reporters
/// that print the source chain will show it twice; `#[error(transparent)]` would
/// avoid that but changes what `source()` returns — confirm callers first.
#[error("{0}")]
Scheduling(#[from] PassSchedulingError),
}
/// Instantiates one fresh [`ProgramPassKind`] per scheduled registration.
///
/// Order follows [`registered_pass_registrations`]. Every call runs each
/// registration's `factory` again, so callers receive independent instances.
///
/// # Errors
///
/// Propagates the scheduling error cached by [`registered_pass_registrations`].
pub fn registered_passes() -> Result<Vec<ProgramPassKind>, OptimizerError> {
    let scheduled = registered_pass_registrations()?;
    let passes: Vec<ProgramPassKind> = scheduled
        .iter()
        .map(|registration| ProgramPassKind::from_boxed((registration.factory)()))
        .collect();
    Ok(passes)
}
/// Returns every `inventory`-registered pass, as ordered by [`schedule_passes`].
///
/// Collection and scheduling happen once per process inside a `LazyLock`.
/// Both success and failure are cached, so later calls only clone the `Arc`
/// (a reference-count bump) or the cached error.
///
/// # Errors
///
/// Returns [`OptimizerError::Scheduling`] when `schedule_passes` rejected the
/// registered set. Because the outcome is cached, every call reports the same
/// error.
// No `#[must_use]` here: `Result` is already `#[must_use]`, and the extra
// attribute triggers `clippy::double_must_use`.
pub fn registered_pass_registrations(
) -> Result<Arc<[&'static ProgramPassRegistration]>, OptimizerError> {
    static SCHEDULED: LazyLock<
        Result<Arc<[&'static ProgramPassRegistration]>, PassSchedulingError>,
    > = LazyLock::new(|| {
        let registrations: Vec<&'static ProgramPassRegistration> =
            inventory::iter::<ProgramPassRegistration>().collect();
        schedule_passes(&registrations).map(|scheduled| scheduled.into_boxed_slice().into())
    });
    match &*SCHEDULED {
        Ok(registrations) => Ok(Arc::clone(registrations)),
        Err(error) => Err(OptimizerError::from(error.clone())),
    }
}
/// Runs the default pass pipeline over `program`.
///
/// Convenience wrapper around `PassScheduler::default().run(program)`.
///
/// # Errors
///
/// Propagates the [`OptimizerError`] returned by `PassScheduler::run`.
pub fn optimize(program: Program) -> Result<Program, OptimizerError> {
PassScheduler::default().run(program)
}
#[must_use]
pub fn pipeline_fingerprint_bytes(program: &Program) -> [u8; 32] {
program.fingerprint()
}
/// Condenses the pipeline fingerprint into a `u64`.
///
/// The result is the first eight bytes of [`pipeline_fingerprint_bytes`] read
/// as a little-endian integer; the remaining 24 bytes are ignored.
#[must_use]
pub fn fingerprint_program(program: &Program) -> u64 {
    let digest = pipeline_fingerprint_bytes(program);
    let mut prefix = [0u8; 8];
    prefix.copy_from_slice(&digest[..8]);
    u64::from_le_bytes(prefix)
}
/// Reports whether every name in `metadata.requires` is present in `available`.
///
/// A pass with an empty `requires` list is always satisfied.
#[inline]
fn requirements_satisfied(metadata: PassMetadata, available: &FxHashSet<&'static str>) -> bool {
    for needed in metadata.requires {
        if !available.contains(needed) {
            return false;
        }
    }
    true
}
#[cfg(test)]
mod optimizer_framework_tests {
    // Unit tests for the optimizer framework types and the built-in pass registry.
    use super::*;
    use crate::ir::{BufferDecl, DataType, Expr, Node, Program};

    /// Minimal program: stores `input[0]` into `out[0]`.
    fn trivial_program() -> Program {
        Program::wrapped(
            vec![
                BufferDecl::read("input", 0, DataType::U32).with_count(4),
                BufferDecl::output("out", 1, DataType::U32).with_count(1),
            ],
            [1, 1, 1],
            vec![Node::store(
                "out",
                Expr::u32(0),
                Expr::load("input", Expr::u32(0)),
            )],
        )
    }

    const _: () = assert!(PassAnalysis::RUN.should_run);
    const _: () = assert!(!PassAnalysis::SKIP.should_run);

    #[test]
    fn pass_result_unchanged_reports_no_change() {
        let p = trivial_program();
        let result = PassResult::unchanged(p);
        assert!(!result.changed);
    }

    #[test]
    fn pass_result_from_programs_identical() {
        let p = trivial_program();
        let result = PassResult::from_programs(&p, p.clone());
        assert!(!result.changed);
    }

    #[test]
    fn pass_metadata_construction() {
        let meta = PassMetadata {
            name: "test_pass",
            requires: &["dead_buffer_elim"],
            invalidates: &["fusion"],
        };
        assert_eq!(meta.name, "test_pass");
        assert_eq!(meta.requires.len(), 1);
        assert_eq!(meta.invalidates.len(), 1);
    }

    #[test]
    fn optimizer_error_max_iterations_display() {
        let err = OptimizerError::MaxIterations {
            max_iterations: 100,
            last_pass: "const_fold",
        };
        let msg = err.to_string();
        assert!(msg.contains("100"));
        assert!(msg.contains("const_fold"));
    }

    #[test]
    fn optimizer_error_unsatisfied_requirement_display() {
        let err = OptimizerError::UnsatisfiedRequirement {
            pass: "fusion",
            missing: "dead_buffer_elim",
        };
        let msg = err.to_string();
        assert!(msg.contains("fusion"));
        assert!(msg.contains("dead_buffer_elim"));
    }

    #[test]
    fn fingerprint_is_deterministic() {
        let p = trivial_program();
        let a = fingerprint_program(&p);
        let b = fingerprint_program(&p);
        assert_eq!(a, b);
    }

    #[test]
    fn fingerprint_different_programs_differ() {
        let p1 = trivial_program();
        let p2 = Program::wrapped(
            vec![
                BufferDecl::read("input", 0, DataType::U32).with_count(4),
                BufferDecl::output("out", 1, DataType::U32).with_count(1),
            ],
            [1, 1, 1],
            vec![Node::store("out", Expr::u32(0), Expr::u32(42))],
        );
        assert_ne!(fingerprint_program(&p1), fingerprint_program(&p2));
    }

    #[test]
    fn requirements_satisfied_empty_requires() {
        let meta = PassMetadata {
            name: "trivial",
            requires: &[],
            invalidates: &[],
        };
        let available = FxHashSet::default();
        assert!(requirements_satisfied(meta, &available));
    }

    #[test]
    fn requirements_satisfied_missing_dep() {
        let meta = PassMetadata {
            name: "needs_stuff",
            requires: &["missing"],
            invalidates: &[],
        };
        let available = FxHashSet::default();
        assert!(!requirements_satisfied(meta, &available));
    }

    #[test]
    fn refusal_reason_kind_tags_are_stable() {
        let cost = RefusalReason::CostIncrease {
            delta: 17,
            detail: "fusion would add 12 atomic ops",
        };
        assert_eq!(cost.kind(), "cost_increase");
        let effect = RefusalReason::EffectLatticeViolation {
            producer: "vyre-libs::dataflow::reaching",
            consumer: "vyre-primitives::reduce::scan",
            suggested_fix: "insert MemoryOrdering::GridSync between arms",
        };
        assert_eq!(effect.kind(), "effect_lattice_violation");
        let wire = RefusalReason::WireContractViolation {
            detail: "op_id drift detected: vyre-primitives::math::add became vyre::add",
        };
        assert_eq!(wire.kind(), "wire_contract_violation");
        let other = RefusalReason::Other {
            detail: "user-provided refusal",
        };
        assert_eq!(other.kind(), "other");
    }

    #[test]
    fn refusal_reason_display_includes_payload() {
        let cost = RefusalReason::CostIncrease {
            delta: 42,
            detail: "extra atomics",
        };
        let msg = cost.to_string();
        assert!(msg.starts_with("cost_increase"));
        assert!(msg.contains("delta=42"));
        assert!(msg.contains("reason=extra atomics"));
        let effect = RefusalReason::EffectLatticeViolation {
            producer: "p",
            consumer: "c",
            suggested_fix: "barrier",
        };
        let msg = effect.to_string();
        assert!(msg.starts_with("effect_lattice_violation"));
        // Assert on full `key=value` pairs: bare "p" / "c" would already be
        // satisfied by the labels ("producer", "effect") even if a payload
        // were dropped.
        assert!(msg.contains("producer=p"));
        assert!(msg.contains("consumer=c"));
        assert!(msg.contains("fix=barrier"));
    }

    #[test]
    fn try_transform_default_delegates_to_transform_for_every_builtin() {
        let p = trivial_program();
        let passes = registered_passes().expect(
            "Fix: registered_passes should succeed; restore this invariant before continuing.",
        );
        for pass in passes {
            let result = pass.try_transform(p.clone());
            assert!(
                result.is_ok(),
                "built-in pass `{}` unexpectedly returned a refusal",
                pass.metadata().name
            );
        }
    }

    #[test]
    fn preserves_default_is_empty_for_every_builtin() {
        let passes = registered_passes().expect(
            "Fix: registered_passes should succeed; restore this invariant before continuing.",
        );
        for pass in passes {
            assert!(
                pass.preserves().is_empty(),
                "built-in pass `{}` declared a preserves[] entry but the scheduler doesn't \
                 yet honor it; either wire the scheduler or remove the declaration",
                pass.metadata().name
            );
        }
    }

    #[test]
    fn registered_passes_includes_builtins() {
        let passes = registered_passes().expect(
            "Fix: registered_passes should succeed; restore this invariant before continuing.",
        );
        assert!(passes.len() >= 19, "at least 19 builtin passes");
        let names: Vec<_> = passes.iter().map(|p| p.metadata().name).collect();
        assert!(names.contains(&"autotune"));
        assert!(names.contains(&"buffer_decl_sort"));
        assert!(names.contains(&"canonicalize"));
        assert!(names.contains(&"const_fold"));
        assert!(names.contains(&"loop_redundant_bound_check_elide"));
        assert!(names.contains(&"loop_trip_zero_eliminate"));
        assert!(names.contains(&"if_constant_branch_eliminate"));
        assert!(names.contains(&"empty_block_collapse"));
        assert!(names.contains(&"noop_assign_eliminate"));
        assert!(names.contains(&"region_promote_singleton_block"));
        assert!(names.contains(&"decode_scan_fuse"));
        assert!(names.contains(&"loop_unroll"));
        assert!(names.contains(&"vectorization"));
        assert!(names.contains(&"dead_buffer_elim"));
    }
}