use crate::ir_inner::model::program::Program;
use rustc_hash::FxHashSet;
use std::sync::{Arc, LazyLock};
pub mod cost;
pub mod ctx;
pub mod derived_order;
pub mod diff_compile;
pub mod effect_lattice;
pub mod fact_substrate;
pub mod fusion_cert;
pub mod program_shape_facts;
pub mod shape_facts;
pub mod algebraic_rules;
pub mod dsl;
pub mod eqsat;
pub mod eqsat_gpu;
pub mod eqsat_toml;
pub mod expr_arena;
pub mod expr_arena_analysis;
pub mod hot_path_hints;
pub mod megakernel;
pub mod pass_catalog;
pub mod pass_explain;
pub mod pass_invariants;
pub mod pass_order;
pub mod pass_selection;
pub mod passes;
pub mod planar_batch;
pub mod pre_lowering;
pub mod program_soa;
mod rewrite;
pub mod rewrite_proof;
pub mod rewrite_proof_registry;
mod scheduler;
#[cfg(test)]
mod tests;
pub use ctx::{scheduling_error_to_diagnostic, AdapterCaps, AnalysisCache, PassCtx};
pub use derived_order::{
derive_pass_order, derive_registered_pass_order, DerivedPassEdge, DerivedPassEdgeKind,
DerivedPassNode, DerivedPassOrder,
};
pub use fusion_cert::FusionCertificate;
pub use pass_explain::{
explain_optimizer_report, explain_optimizer_report_with_catalog, CatalogLookupStatus,
PassExplanation, PassMetricDelta,
};
pub use pass_order::{
validate_registered_pass_order, validate_scheduled_pass_order, PassOrderValidation,
};
pub use pass_selection::{
registered_passes_for_profile_and_program, select_pass_metadata_for_program,
PassSelectionDecision, PassSelectionReason,
};
pub use planar_batch::{
default_planar_rewrite_batch_threshold, planar_rewrite_schedule_mask, RewriteBatch,
RewriteBatchCandidates, RewriteBatchItem, RewriteBatchPlan, RewriteCandidate,
};
pub use scheduler::{
schedule_passes, OptimizerRunReport, PassRunDecision, PassRunMetric, PassScheduler,
PassSchedulingError,
};
pub use vyre_macros::vyre_pass;
/// Static description of an optimizer pass, used for scheduling and for
/// profile-based selection without instantiating the pass itself.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct PassMetadata {
    /// Pass identifier; `ProgramPass::pass_id` returns this by default.
    pub name: &'static str,
    /// Names that must already be available before this pass may run
    /// (checked by `requirements_satisfied`).
    pub requires: &'static [&'static str],
    /// Names this pass declares invalid after it runs.
    // NOTE(review): presumably analyses/facts in the same namespace as `requires` — confirm.
    pub invalidates: &'static [&'static str],
    /// Phase classification consulted by `OptimizerProfile::accepts`.
    pub phase: PassPhase,
    /// Boundary classification consulted by `OptimizerProfile::accepts`.
    pub boundary_class: PassBoundaryClass,
    /// Adapter capabilities required; `OptimizerProfile::Release` only accepts an empty list.
    pub requires_caps: &'static [&'static str],
    /// Whether the pass keeps the program ABI intact; every profile requires `true`.
    pub preserves_abi: bool,
    /// Cost-model family tag (not read in this module).
    pub cost_model_family: CostModelFamily,
}
impl PassMetadata {
    /// Builds metadata from a name plus `requires`/`invalidates` lists.
    ///
    /// All classification fields get their "unknown" values: `PassPhase::Unclassified`,
    /// `PassBoundaryClass::Unknown` and `CostModelFamily::Unknown`. They also get no
    /// required caps and `preserves_abi = true`. With these defaults, no
    /// `OptimizerProfile` accepts the pass until the classification is filled in.
    #[must_use]
    pub const fn new(
        name: &'static str,
        requires: &'static [&'static str],
        invalidates: &'static [&'static str],
    ) -> Self {
        Self {
            cost_model_family: CostModelFamily::Unknown,
            preserves_abi: true,
            requires_caps: &[],
            boundary_class: PassBoundaryClass::Unknown,
            phase: PassPhase::Unclassified,
            invalidates,
            requires,
            name,
        }
    }
}
/// Optimizer phase a pass is classified under.
///
/// `OptimizerProfile::accepts` uses it as follows:
/// - `Dataflow` and `Megakernel` are each accepted only by the profile of the same name.
/// - `Unclassified` is rejected by every profile.
/// - The remaining phases are eligible for `OptimizerProfile::Release`.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PassPhase {
    /// Default assigned by `PassMetadata::new`; no profile accepts it.
    Unclassified,
    Canonicalization,
    ScalarAlgebra,
    Loop,
    Memory,
    FusionCse,
    Sync,
    Specialization,
    Cleanup,
    /// Only phase accepted by `OptimizerProfile::Dataflow`.
    Dataflow,
    /// Only phase accepted by `OptimizerProfile::Megakernel`.
    Megakernel,
}
/// What kind of boundary a pass may touch, as checked by `OptimizerProfile::accepts`.
///
/// - `AbiPreserving` is accepted by every built-in profile.
/// - `DomainSpecific` is additionally accepted by the `Dataflow` profile.
/// - `RuntimeAware` is additionally accepted by the `Megakernel` profile.
/// - `Unknown`, `AbiChanging` and `BackendAware` are accepted by none of them.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PassBoundaryClass {
    /// Default assigned by `PassMetadata::new`.
    Unknown,
    AbiPreserving,
    AbiChanging,
    BackendAware,
    RuntimeAware,
    DomainSpecific,
}
/// Cost-model family tag carried in `PassMetadata`.
// NOTE(review): not consulted anywhere in this module; presumably read by the `cost`
// submodule or the scheduler — confirm.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CostModelFamily {
    /// Default assigned by `PassMetadata::new`.
    Unknown,
    Scalar,
    Loop,
    Memory,
    Fusion,
    Sync,
    Dataflow,
    Megakernel,
}
/// Verdict returned by `ProgramPass::analyze`: whether the pass wants to run
/// on the inspected program.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct PassAnalysis {
    /// `true` when the pass reports it should run.
    pub should_run: bool,
}
impl PassAnalysis {
    /// Verdict asking for the pass to run.
    pub const RUN: PassAnalysis = PassAnalysis { should_run: true };
    /// Verdict asking for the pass to be skipped.
    pub const SKIP: PassAnalysis = PassAnalysis { should_run: false };
}
/// Outcome of applying a pass: the resulting program and whether it changed.
#[derive(Clone, Debug, PartialEq)]
pub struct PassResult {
    /// Program after the pass (possibly the input returned untouched).
    pub program: Program,
    /// Whether the pass reports having modified the program.
    pub changed: bool,
}
/// Structured reason a pass declined to transform a program.
///
/// This is the error side of `ProgramPass::try_transform` and
/// `ProgramPass::try_batch_apply`. [`RefusalReason::kind`] gives a stable
/// snake_case tag, and `Display` renders `<kind>: <details>`.
#[non_exhaustive]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RefusalReason {
    /// Rejected for increasing cost; `delta` is the reported change.
    CostIncrease { delta: i64, detail: &'static str },
    /// A producer/consumer pair would violate the effect lattice.
    EffectLatticeViolation {
        producer: &'static str,
        consumer: &'static str,
        suggested_fix: &'static str,
    },
    /// The rewrite would break the wire contract.
    WireContractViolation { detail: &'static str },
    /// Any other refusal.
    Other { detail: &'static str },
}
impl RefusalReason {
    /// Stable snake_case tag naming the refusal variant.
    #[must_use]
    pub fn kind(&self) -> &'static str {
        use RefusalReason as R;
        match *self {
            R::CostIncrease { .. } => "cost_increase",
            R::EffectLatticeViolation { .. } => "effect_lattice_violation",
            R::WireContractViolation { .. } => "wire_contract_violation",
            R::Other { .. } => "other",
        }
    }
}
impl std::fmt::Display for RefusalReason {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // All payload fields are `Copy`, so bind them by value.
        match *self {
            Self::CostIncrease { delta, detail } => {
                write!(f, "cost_increase: delta={delta} reason={detail}")
            }
            Self::EffectLatticeViolation {
                producer,
                consumer,
                suggested_fix,
            } => write!(
                f,
                "effect_lattice_violation: producer={producer} consumer={consumer} fix={suggested_fix}"
            ),
            Self::WireContractViolation { detail } => write!(f, "wire_contract_violation: {detail}"),
            Self::Other { detail } => write!(f, "other: {detail}"),
        }
    }
}
impl PassResult {
    /// Wraps `program` and sets `changed` when it compares unequal to `before`.
    #[must_use]
    #[inline]
    pub fn from_programs(before: &Program, program: Program) -> Self {
        // The comparison is evaluated (and its borrow ends) before `program` is moved.
        Self {
            changed: *before != program,
            program,
        }
    }
    /// Wraps `program` with `changed = false`.
    #[must_use]
    #[inline]
    pub fn unchanged(program: Program) -> Self {
        Self { program, changed: false }
    }
}
/// Static registration record for a [`ProgramPass`], gathered through `inventory`.
///
/// `registered_pass_registrations` iterates every submitted record and orders them
/// with `schedule_passes`. The `registered_passes*` helpers call `factory` to
/// instantiate passes, and filter by profile using only `metadata`.
#[derive(Debug)]
pub struct ProgramPassRegistration {
    /// Metadata available without instantiating the pass.
    pub metadata: PassMetadata,
    /// Constructor returning a boxed instance of the pass.
    pub factory: fn() -> Box<dyn ProgramPass>,
}
// Declare the `inventory` collection so records submitted elsewhere are
// reachable through `inventory::iter::<ProgramPassRegistration>()`.
inventory::collect!(ProgramPassRegistration);
/// Sealing module for [`ProgramPass`].
///
/// `Sealed` is a supertrait of `ProgramPass`. Because this module is `pub(crate)`,
/// code outside this crate cannot name `Sealed` directly.
pub(crate) mod private {
    /// Marker supertrait restricting who can implement `ProgramPass`.
    pub trait Sealed {}
}
/// A program-to-program optimizer pass.
///
/// Implementors must provide `metadata`, `analyze`, `transform` and `fingerprint`;
/// every other method has a default. The `private::Sealed` supertrait lives in a
/// `pub(crate)` module, so `Sealed` can only be named from inside this crate.
// NOTE(review): presumably the `vyre_pass` macro re-exported above emits the
// `Sealed` impl for user passes — confirm.
pub trait ProgramPass: private::Sealed + Send + Sync {
    /// Static scheduling/selection metadata for this pass.
    fn metadata(&self) -> PassMetadata;
    /// Names this pass declares it keeps valid. Default: none.
    // NOTE(review): presumably the same namespace as `PassMetadata::invalidates` — confirm.
    fn preserves(&self) -> &'static [&'static str] {
        &[]
    }
    /// Identifier of the pass; defaults to `metadata().name`.
    fn pass_id(&self) -> &'static str {
        self.metadata().name
    }
    /// Inspect `program` and report whether this pass should run on it.
    fn analyze(&self, program: &Program) -> PassAnalysis;
    /// Apply the pass to the whole program, returning the result and a `changed` flag.
    fn transform(&self, program: Program) -> PassResult;
    /// Whether `batch_apply` should take the planar candidate/batch path.
    /// Default: `false`, which makes `batch_apply` simply call `transform`.
    fn supports_planar_batching(&self) -> bool {
        false
    }
    /// Collect rewrite candidates for planar batching. Default: no candidates.
    fn rewrite_candidates(&self, _program: &Program) -> RewriteBatchCandidates {
        RewriteBatchCandidates::empty()
    }
    /// Apply one planned batch. The default ignores `_batch` and runs the full
    /// `transform`, so passes that opt into batching are expected to override it.
    fn apply_rewrite_batch(&self, program: Program, _batch: &RewriteBatch) -> PassResult {
        self.transform(program)
    }
    /// Apply the pass, using planar batching when the pass supports it.
    ///
    /// Decision order of the default implementation:
    /// 1. batching unsupported -> `transform`;
    /// 2. no candidates -> unchanged;
    /// 3. `candidates.should_batch()` is false -> `transform`;
    /// 4. the plan has no batches -> unchanged;
    /// 5. otherwise each planned batch is applied in order. The program is threaded
    ///    through every batch, and the per-batch `changed` flags are OR-ed together.
    fn batch_apply(&self, program: Program) -> PassResult {
        if !self.supports_planar_batching() {
            return self.transform(program);
        }
        let candidates = self.rewrite_candidates(&program);
        if candidates.is_empty() {
            return PassResult::unchanged(program);
        }
        if !candidates.should_batch() {
            return self.transform(program);
        }
        let plan = candidates.plan();
        if !plan.has_batches() {
            return PassResult::unchanged(program);
        }
        let mut changed = false;
        let mut program = program;
        for batch in plan.batches() {
            let result = self.apply_rewrite_batch(program, batch);
            // Any batch reporting a change marks the whole application as changed.
            changed |= result.changed;
            program = result.program;
        }
        PassResult { program, changed }
    }
    /// Fallible variant of `transform`. The default never refuses; it wraps
    /// `transform` in `Ok`.
    fn try_transform(&self, program: Program) -> Result<PassResult, RefusalReason> {
        Ok(self.transform(program))
    }
    /// Fallible counterpart of `batch_apply`. The default delegates to
    /// `try_transform`, so it does NOT go through the planar batching path.
    /// Override it to combine refusals with batching.
    fn try_batch_apply(&self, program: Program) -> Result<PassResult, RefusalReason> {
        self.try_transform(program)
    }
    /// Effects this pass declares it may add. Default: `ProgramEffects::empty()`.
    fn allowed_effect_additions(&self) -> crate::lower::effects::ProgramEffects {
        crate::lower::effects::ProgramEffects::empty()
    }
    /// Pass-specific `u64` fingerprint of `program` (required).
    // NOTE(review): how callers use this (e.g. caching keys) is not visible here — confirm.
    fn fingerprint(&self, program: &Program) -> u64;
}
/// Owned, type-erased handle around a boxed [`ProgramPass`].
///
/// Every inherent method forwards unchanged to the wrapped pass.
pub struct ProgramPassKind(Box<dyn ProgramPass>);
impl ProgramPassKind {
    /// Boxes a concrete pass value.
    #[must_use]
    #[inline]
    pub fn new<P: ProgramPass + 'static>(pass: P) -> Self {
        let erased: Box<dyn ProgramPass> = Box::new(pass);
        Self::from_boxed(erased)
    }
    /// Wraps an already-boxed pass, such as one produced by a registration factory.
    #[must_use]
    #[inline]
    pub fn from_boxed(pass: Box<dyn ProgramPass>) -> Self {
        Self(pass)
    }
    /// Forwards to [`ProgramPass::metadata`].
    #[must_use]
    #[inline]
    pub fn metadata(&self) -> PassMetadata {
        let Self(pass) = self;
        pass.metadata()
    }
    /// Forwards to [`ProgramPass::pass_id`].
    #[must_use]
    #[inline]
    pub fn pass_id(&self) -> &'static str {
        let Self(pass) = self;
        pass.pass_id()
    }
    /// Forwards to [`ProgramPass::analyze`].
    #[must_use]
    #[inline]
    pub fn analyze(&self, program: &Program) -> PassAnalysis {
        let Self(pass) = self;
        pass.analyze(program)
    }
    /// Forwards to [`ProgramPass::transform`].
    #[must_use]
    #[inline]
    pub fn transform(&self, program: Program) -> PassResult {
        let Self(pass) = self;
        pass.transform(program)
    }
    /// Forwards to [`ProgramPass::batch_apply`].
    #[must_use]
    #[inline]
    pub fn batch_apply(&self, program: Program) -> PassResult {
        let Self(pass) = self;
        pass.batch_apply(program)
    }
    /// Forwards to [`ProgramPass::try_transform`].
    #[inline]
    pub fn try_transform(&self, program: Program) -> Result<PassResult, RefusalReason> {
        let Self(pass) = self;
        pass.try_transform(program)
    }
    /// Forwards to [`ProgramPass::try_batch_apply`].
    #[inline]
    pub fn try_batch_apply(&self, program: Program) -> Result<PassResult, RefusalReason> {
        let Self(pass) = self;
        pass.try_batch_apply(program)
    }
    /// Forwards to [`ProgramPass::preserves`].
    #[must_use]
    #[inline]
    pub fn preserves(&self) -> &'static [&'static str] {
        let Self(pass) = self;
        pass.preserves()
    }
    /// Forwards to [`ProgramPass::allowed_effect_additions`].
    #[must_use]
    #[inline]
    pub fn allowed_effect_additions(&self) -> crate::lower::effects::ProgramEffects {
        let Self(pass) = self;
        pass.allowed_effect_additions()
    }
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum OptimizerProfile {
Release,
Dataflow,
Megakernel,
}
impl OptimizerProfile {
#[must_use]
pub fn accepts(self, metadata: PassMetadata) -> bool {
match self {
Self::Release => {
metadata.preserves_abi
&& metadata.boundary_class == PassBoundaryClass::AbiPreserving
&& metadata.requires_caps.is_empty()
&& !matches!(
metadata.phase,
PassPhase::Dataflow | PassPhase::Megakernel | PassPhase::Unclassified
)
}
Self::Dataflow => {
metadata.preserves_abi
&& matches!(
metadata.boundary_class,
PassBoundaryClass::AbiPreserving | PassBoundaryClass::DomainSpecific
)
&& metadata.phase == PassPhase::Dataflow
}
Self::Megakernel => {
metadata.preserves_abi
&& matches!(
metadata.boundary_class,
PassBoundaryClass::AbiPreserving | PassBoundaryClass::RuntimeAware
)
&& metadata.phase == PassPhase::Megakernel
}
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
#[non_exhaustive]
pub enum OptimizerError {
#[error(
"optimizer did not reach a fixpoint after {max_iterations} iterations. Fix: inspect pass `{last_pass}` for oscillating rewrites or raise the cap only with a convergence certificate."
)]
MaxIterations {
max_iterations: usize,
last_pass: &'static str,
},
#[error(
"optimizer pass `{pass}` requires `{missing}` but no prior pass provides it. Fix: register the required analysis pass or remove the stale requirement."
)]
UnsatisfiedRequirement {
pass: &'static str,
missing: &'static str,
},
#[error("{0}")]
Scheduling(#[from] PassSchedulingError),
#[error(
"pre-lowering phase {phase} did not converge after {max} iterations. Fix: inspect the phase for oscillating rewrites or raise the cap only with a convergence certificate."
)]
PreLoweringIterationLimit {
phase: u32,
max: usize,
},
}
/// Instantiates every registered pass, in scheduled order.
///
/// # Errors
/// Propagates the scheduling error from [`registered_pass_registrations`].
pub fn registered_passes() -> Result<Vec<ProgramPassKind>, OptimizerError> {
    let passes: Vec<ProgramPassKind> = registered_pass_registrations()?
        .iter()
        .map(|registration| ProgramPassKind::from_boxed((registration.factory)()))
        .collect();
    Ok(passes)
}
/// Instantiates, in scheduled order, only the registered passes `profile` accepts.
///
/// Filtering uses each registration's static metadata, so rejected passes are
/// never constructed.
///
/// # Errors
/// Propagates the scheduling error from [`registered_pass_registrations`].
pub fn registered_passes_for_profile(
    profile: OptimizerProfile,
) -> Result<Vec<ProgramPassKind>, OptimizerError> {
    let selected: Vec<ProgramPassKind> = registered_pass_registrations()?
        .iter()
        .filter(|registration| profile.accepts(registration.metadata))
        .map(|registration| ProgramPassKind::from_boxed((registration.factory)()))
        .collect();
    Ok(selected)
}
/// Metadata of the registered passes `profile` accepts, in scheduled order,
/// without instantiating any pass.
///
/// # Errors
/// Propagates the scheduling error from [`registered_pass_registrations`].
pub fn registered_pass_metadata_for_profile(
    profile: OptimizerProfile,
) -> Result<Vec<PassMetadata>, OptimizerError> {
    let registrations = registered_pass_registrations()?;
    let mut accepted = Vec::new();
    for registration in registrations.iter() {
        let metadata = registration.metadata;
        if profile.accepts(metadata) {
            accepted.push(metadata);
        }
    }
    Ok(accepted)
}
/// All `inventory`-registered passes, ordered by [`schedule_passes`].
///
/// Collection and scheduling happen once per process, memoized in a `LazyLock`.
/// Later calls return an `Arc` clone of the cached slice.
///
/// # Errors
/// Returns [`OptimizerError::Scheduling`] when `schedule_passes` rejected the
/// registered set. The cached error is cloned on every call.
pub fn registered_pass_registrations(
) -> Result<Arc<[&'static ProgramPassRegistration]>, OptimizerError> {
    static SCHEDULED: LazyLock<
        Result<Arc<[&'static ProgramPassRegistration]>, PassSchedulingError>,
    > = LazyLock::new(|| {
        let registrations: Vec<&'static ProgramPassRegistration> =
            inventory::iter::<ProgramPassRegistration>().collect();
        schedule_passes(&registrations).map(|scheduled| scheduled.into_boxed_slice().into())
    });
    match &*SCHEDULED {
        Ok(registrations) => Ok(Arc::clone(registrations)),
        Err(error) => Err(OptimizerError::from(error.clone())),
    }
}
/// Runs the default pass pipeline (`PassScheduler::try_default`) over `program`.
///
/// The default scheduler is built once per process.
///
/// # Errors
/// If building the scheduler failed, that error is cloned and returned on every
/// call. Otherwise any error from the run itself is returned.
pub fn optimize(program: Program) -> Result<Program, OptimizerError> {
    static DEFAULT_SCHEDULER: LazyLock<Result<PassScheduler, OptimizerError>> =
        LazyLock::new(PassScheduler::try_default);
    let cached: &Result<PassScheduler, OptimizerError> = &DEFAULT_SCHEDULER;
    let scheduler = cached.as_ref().map_err(|err| err.clone())?;
    scheduler.run(program)
}
/// Optimizes `program` with the passes selected for `profile`, `program` and `hints`.
///
/// Unlike [`optimize`], this builds a fresh [`PassScheduler`] on every call from
/// the passes returned by `pass_selection::registered_passes_for_profile_and_program`.
///
/// # Errors
/// Returns an error if pass selection, scheduler construction, or the run fails.
pub fn optimize_with_hot_path_hints(
    program: Program,
    profile: OptimizerProfile,
    hints: &hot_path_hints::HotPathHints,
) -> Result<Program, OptimizerError> {
    let selected =
        pass_selection::registered_passes_for_profile_and_program(profile, &program, hints)?;
    let scheduler = PassScheduler::try_with_passes(selected)?;
    scheduler.run(program)
}
#[must_use]
pub fn pipeline_fingerprint_bytes(program: &Program) -> [u8; 32] {
program.fingerprint()
}
/// Compact `u64` fingerprint of `program`.
///
/// It is the leading 8 bytes of [`pipeline_fingerprint_bytes`], read as little-endian.
#[must_use]
pub fn fingerprint_program(program: &Program) -> u64 {
    let [b0, b1, b2, b3, b4, b5, b6, b7, ..] = pipeline_fingerprint_bytes(program);
    u64::from_le_bytes([b0, b1, b2, b3, b4, b5, b6, b7])
}
/// Returns `true` when every name in `metadata.requires` is in `available`.
/// An empty `requires` list is trivially satisfied.
#[inline]
fn requirements_satisfied(metadata: PassMetadata, available: &FxHashSet<&'static str>) -> bool {
    for requirement in metadata.requires {
        if !available.contains(requirement) {
            return false;
        }
    }
    true
}
#[cfg(test)]
mod framework_tests;