pub mod beam;
pub mod config;
pub mod error;
pub mod heuristics;
pub mod kernel_info;
pub mod opts;
pub mod renderer;
pub mod scheduler;
pub mod tc;
pub mod types;
pub use beam::{BeamResult, beam_search, beam_search_cached, beam_search_with_timeout, clear_cache, replay_opts};
pub use config::{BeamConfig, HeuristicsConfig, OptStrategy, OptimizerConfig, TcOpt as TcOptLevel, TcSelect, TcUsage};
pub use error::OptError;
pub use heuristics::hand_coded_optimizations;
pub use kernel_info::KernelInfo;
pub use opts::apply_opt;
pub use renderer::{Renderer, TcOpt, TensorCore};
pub use scheduler::Scheduler;
#[cfg(test)]
pub use scheduler::clear_kernel_name_counts;
pub use types::{AxisType, Opt, OptArg, OptOps};
use crate::devectorize::{
Fp8DecompCtx, bool_storage_patterns, pm_float_decomp, pm_float_decomp_store, pm_reduce, pm_render,
pm_wmma_accumulate,
};
use crate::gpudims::pm_add_gpudims;
use crate::rangeify::patterns::{
pm_add_loads, pm_comparison_negations, pm_demorgan, pm_div_to_shr, pm_erf_decomposition, pm_fdiv_to_mul,
pm_fma_decomposition, pm_load_collapse, pm_mod_to_and, pm_mul_to_shl, pm_neg_from_mul, pm_shl_add_to_mulacc,
pm_threefry_decomp, rangeify_codegen_with_kernel_ctx,
};
use crate::rangeify::pm_add_buffers_local_patterns;
use crate::rangeify::transforms::{pm_flatten_range, pm_simplify_ranges, pm_split_ranges};
use crate::rewrite::graph_rewrite;
use crate::symbolic::patterns::{gep_pushing_patterns, sym, symbolic, symbolic_simple};
use std::sync::{Arc, LazyLock};
/// Optimizes `ast` for `renderer` using the configuration built by
/// `OptimizerConfig::from_env()`.
///
/// Thin convenience wrapper over [`optimize_kernel_with_config`].
pub fn optimize_kernel(ast: Arc<morok_ir::UOp>, renderer: &Renderer) -> Arc<morok_ir::UOp> {
    let env_config = OptimizerConfig::from_env();
    optimize_kernel_with_config(ast, renderer, &env_config)
}
#[tracing::instrument(skip_all)]
pub fn apply_post_optimization(ast: Arc<morok_ir::UOp>) -> Arc<morok_ir::UOp> {
apply_post_optimization_with_renderer(ast, None)
}
/// Runs the post-optimization lowering pipeline on an optimized kernel AST.
///
/// Stages, in order:
/// 1. `sym()` plus `pm_move_where_on_load`,
/// 2. `pre_expand`,
/// 3. local-buffer insertion combined with rangeify codegen (fresh `KernelContext`),
/// 4. `pm_reduce` combined with WMMA accumulation and GEP pushing,
/// 5. `pm_add_gpudims`, only when `renderer` is `Some` and its `has_local` is set,
/// 6. `pm_add_loads`,
/// 7. `devectorize`,
/// 8. index-dtype lowering combined with load/store indexing and GEP pushing,
/// 9. a second `symbolic()` pass,
/// 10. `symbolic_simple` + the late rewrite set (`get_late_rewrite_patterns`) + `pm_render`,
/// 11. `merge_sibling_ends`,
/// 12. FP8 (E5M2, then E4M3) to Float16 decomposition,
/// 13. bool storage rewriting.
///
/// Raw metadata attached to the input `ast` (if any) is re-attached to the returned root.
///
/// In debug builds only, intermediate graphs are scanned for `UNROLL` nodes whose source is
/// a `GROUP` (and, just before devectorize, for any remaining `UNROLL`/`CONTRACT`). Hits
/// are reported with `tracing::error!`; the scans never stop the pipeline.
#[tracing::instrument(skip_all)]
pub fn apply_post_optimization_with_renderer(
    ast: Arc<morok_ir::UOp>,
    renderer: Option<&Renderer>,
) -> Arc<morok_ir::UOp> {
    // Saved before any rewriting; re-attached to the final root at the end.
    let saved_metadata = ast.metadata_raw();
    tracing::debug!(ast.initial = ast.tree(), node_count = ast.node_count(), "kernel initial");

    let t_stage = std::time::Instant::now();
    static POST_OPT_SYM: LazyLock<crate::TypedPatternMatcher> =
        LazyLock::new(|| sym().clone() + crate::symbolic::patterns::pm_move_where_on_load());
    let with_symbolic = graph_rewrite(&*POST_OPT_SYM, ast, &mut ());
    tracing::debug!(
        ast.optimized = with_symbolic.tree(),
        node_count = with_symbolic.node_count(),
        elapsed_ms = t_stage.elapsed().as_millis() as u64,
        "Stage 8: after post-opt symbolic"
    );

    let t_stage = std::time::Instant::now();
    let expanded = crate::expand::pre_expand(&with_symbolic);
    tracing::debug!(
        ast.optimized = expanded.tree(),
        node_count = expanded.node_count(),
        elapsed_ms = t_stage.elapsed().as_millis() as u64,
        "Stage 9: after pre_expand"
    );

    // Diagnostic: reports every UNROLL whose source is a GROUP. It walks the whole graph
    // (toposort), so every call site below is restricted to debug builds.
    let check_unroll_group = |label: &str, root: &Arc<morok_ir::UOp>| {
        for node in root.toposort() {
            if let morok_ir::Op::Unroll { src, unroll_axes, .. } = node.op()
                && matches!(src.op(), morok_ir::Op::Group { .. })
            {
                tracing::error!(id = node.id, axes = ?unroll_axes, stage = label, "UNROLL(GROUP) found!");
            }
        }
    };

    let t_stage = std::time::Instant::now();
    let with_local_buffers = {
        let mut buf_ctx = crate::rangeify::KernelContext::new();
        static PM_LOCAL_BUF: LazyLock<crate::TypedPatternMatcher<crate::rangeify::KernelContext>> =
            LazyLock::new(|| pm_add_buffers_local_patterns() + rangeify_codegen_with_kernel_ctx());
        graph_rewrite(&*PM_LOCAL_BUF, expanded, &mut buf_ctx)
    };
    tracing::debug!(
        ast.optimized = with_local_buffers.tree(),
        node_count = with_local_buffers.node_count(),
        elapsed_ms = t_stage.elapsed().as_millis() as u64,
        "Stage 10: after add local buffers"
    );
    if cfg!(debug_assertions) {
        check_unroll_group("after_add_local_buffers", &with_local_buffers);
    }

    let t_stage = std::time::Instant::now();
    static PM_REDUCE_COMBINED: LazyLock<crate::TypedPatternMatcher<crate::devectorize::ReduceContext>> =
        LazyLock::new(|| pm_reduce() + pm_wmma_accumulate().with_context() + gep_pushing_patterns().with_context());
    let mut reduce_ctx = crate::devectorize::ReduceContext::default();
    let reduced = graph_rewrite(&*PM_REDUCE_COMBINED, with_local_buffers, &mut reduce_ctx);
    tracing::debug!(
        ast.optimized = reduced.tree(),
        node_count = reduced.node_count(),
        elapsed_ms = t_stage.elapsed().as_millis() as u64,
        "after pm_reduce"
    );
    if cfg!(debug_assertions) {
        check_unroll_group("after_pm_reduce", &reduced);
    }

    let t_stage = std::time::Instant::now();
    // GPU dimensions are only added when a renderer is supplied and reports `has_local`;
    // the renderer itself (cloned) serves as the rewrite context.
    let with_gpudims = match renderer.filter(|ren| ren.has_local) {
        Some(ren) => graph_rewrite(&pm_add_gpudims(), reduced, &mut ren.clone()),
        None => reduced,
    };
    tracing::debug!(
        ast.optimized = with_gpudims.tree(),
        node_count = with_gpudims.node_count(),
        elapsed_ms = t_stage.elapsed().as_millis() as u64,
        "after pm_add_gpudims"
    );
    if cfg!(debug_assertions) {
        check_unroll_group("after_pm_add_gpudims", &with_gpudims);
    }

    let t_stage = std::time::Instant::now();
    let with_loads = graph_rewrite(pm_add_loads(), with_gpudims, &mut ());
    tracing::debug!(
        ast.optimized = with_loads.tree(),
        node_count = with_loads.node_count(),
        elapsed_ms = t_stage.elapsed().as_millis() as u64,
        "after pm_add_loads"
    );
    if cfg!(debug_assertions) {
        check_unroll_group("after_pm_add_loads", &with_loads);
        // Any UNROLL or CONTRACT still present at this point is reported before devectorize.
        for node in with_loads.toposort() {
            if let morok_ir::Op::Unroll { src, unroll_axes, .. } = node.op() {
                tracing::error!(
                    id = node.id,
                    src_op = src.op().as_ref(),
                    axes = ?unroll_axes,
                    "BEFORE devectorize: found UNROLL!"
                );
            }
            if let morok_ir::Op::Contract { src, upcast_ranges, .. } = node.op() {
                tracing::error!(
                    id = node.id,
                    src_op = src.op().as_ref(),
                    axes = ?upcast_ranges,
                    "BEFORE devectorize: found CONTRACT!"
                );
            }
        }
    }

    let t_stage = std::time::Instant::now();
    let devectorized = crate::devectorize::devectorize(&with_loads);
    tracing::debug!(
        ast.optimized = devectorized.tree(),
        node_count = devectorized.node_count(),
        elapsed_ms = t_stage.elapsed().as_millis() as u64,
        "after devectorize"
    );
    // Previously unguarded: this full-graph scan also ran in release builds.
    if cfg!(debug_assertions) {
        check_unroll_group("after_devectorize", &devectorized);
    }

    let t_stage = std::time::Instant::now();
    static PM_LOWER_COMBINED: LazyLock<crate::TypedPatternMatcher> = LazyLock::new(|| {
        crate::symbolic::pm_lower_index_dtype()
            + crate::devectorize::load_store_indexing_patterns()
            + gep_pushing_patterns()
    });
    let with_lowered_idx = graph_rewrite(&*PM_LOWER_COMBINED, devectorized, &mut ());
    tracing::debug!(
        ast.optimized = with_lowered_idx.tree(),
        node_count = with_lowered_idx.node_count(),
        elapsed_ms = t_stage.elapsed().as_millis() as u64,
        "after pm_lower_index_dtype"
    );
    // Previously unguarded: this full-graph scan also ran in release builds.
    if cfg!(debug_assertions) {
        check_unroll_group("after_pm_lower_index_dtype", &with_lowered_idx);
    }

    let t_stage = std::time::Instant::now();
    static POST_INDEX_SYM: LazyLock<crate::TypedPatternMatcher> = LazyLock::new(|| symbolic().clone());
    let with_lowered_idx = graph_rewrite(&*POST_INDEX_SYM, with_lowered_idx, &mut ());
    tracing::debug!(
        ast.optimized = with_lowered_idx.tree(),
        node_count = with_lowered_idx.node_count(),
        elapsed_ms = t_stage.elapsed().as_millis() as u64,
        "after post-index symbolic"
    );

    let t_stage = std::time::Instant::now();
    static PM_FINAL: LazyLock<crate::TypedPatternMatcher> =
        LazyLock::new(|| symbolic_simple() + get_late_rewrite_patterns() + pm_render());
    let rendered = graph_rewrite(&*PM_FINAL, with_lowered_idx, &mut ());
    tracing::debug!(
        ast.optimized = rendered.tree(),
        node_count = rendered.node_count(),
        elapsed_ms = t_stage.elapsed().as_millis() as u64,
        "Stage 18-19: after pm_decomp + pm_render"
    );

    let t_merge = std::time::Instant::now();
    let rendered = crate::devectorize::merge_sibling_ends(&rendered);
    tracing::debug!(
        ast.optimized = rendered.tree(),
        node_count = rendered.node_count(),
        elapsed_ms = t_merge.elapsed().as_millis() as u64,
        "after merge_sibling_ends"
    );

    let t_stage = std::time::Instant::now();
    let fp8_pm = pm_float_decomp();
    let fp8_bpm = pm_float_decomp_store();
    let mut fp8_decomposed = rendered;
    // One rewrite pass per FP8 format, each decomposing to Float16.
    for (fr, to) in [
        (morok_dtype::ScalarDType::FP8E5M2, morok_dtype::ScalarDType::Float16),
        (morok_dtype::ScalarDType::FP8E4M3, morok_dtype::ScalarDType::Float16),
    ] {
        let mut ctx = Fp8DecompCtx { from: fr, to };
        fp8_decomposed = morok_ir::rewrite::graph_rewrite_with_bpm(&fp8_pm, &fp8_bpm, fp8_decomposed, &mut ctx);
    }
    tracing::debug!(
        ast.optimized = fp8_decomposed.tree(),
        node_count = fp8_decomposed.node_count(),
        elapsed_ms = t_stage.elapsed().as_millis() as u64,
        "after pm_float_decomp"
    );

    let t_stage = std::time::Instant::now();
    let bs = graph_rewrite(bool_storage_patterns(), fp8_decomposed, &mut ());
    tracing::debug!(
        ast.optimized = bs.tree(),
        node_count = bs.node_count(),
        elapsed_ms = t_stage.elapsed().as_millis() as u64,
        "after bool_storage_pattern"
    );

    match saved_metadata {
        Some(meta) => bs.with_metadata_raw(meta),
        None => bs,
    }
}
/// Returns the cached "late" rewrite set.
///
/// Within this file it is combined with `symbolic_simple()` and `pm_render()` in the final
/// post-optimization stage. The matcher is built once, on first use, by the `LazyLock`.
///
/// NOTE(review): the summand order is preserved as-is in case composition order affects
/// pattern priority; confirm before reordering.
fn get_late_rewrite_patterns() -> &'static crate::TypedPatternMatcher {
    static LATE_PATTERNS: LazyLock<crate::TypedPatternMatcher> = LazyLock::new(|| {
        pm_fma_decomposition()
            + pm_erf_decomposition()
            + pm_mod_to_and()
            + pm_mul_to_shl()
            + pm_div_to_shr()
            + pm_fdiv_to_mul()
            + pm_neg_from_mul()
            + pm_demorgan()
            + pm_shl_add_to_mulacc()
            + pm_threefry_decomp()
            + pm_comparison_negations()
            + crate::symbolic::fast_division_patterns()
            + pm_mod_to_idiv()
    });
    LazyLock::force(&LATE_PATTERNS)
}
/// Pattern set rewriting integer `x % d` into `x - d * (x idiv d)`.
///
/// It applies only when `d` is a constant whose integer value is greater than 1 and is not
/// a power of two. NOTE(review): power-of-two divisors are presumably left to
/// `pm_mod_to_and`, which precedes this set in `get_late_rewrite_patterns`; confirm.
///
/// If `try_mul` fails, the `?` short-circuits the closure with `None`. A failed `try_sub`
/// also yields `None`. Presumably `None` means "no rewrite" in the `cached_patterns!` DSL;
/// verify against the macro.
fn pm_mod_to_idiv() -> &'static crate::TypedPatternMatcher {
crate::cached_patterns! {
Mod(x, d @const(d_val))
if x.dtype().is_int()
&& matches!(d_val.try_int(), Some(v) if v > 1 && !((v as u64).is_power_of_two()))
=> {
let div = x.idiv(d);
let mul = d.try_mul(&div).ok()?;
x.try_sub(&mul).ok()
},
}
}
/// Normalizes a kernel AST before any optimization strategy runs.
///
/// Passes, in order:
/// 1. movement-op patterns + syntactic sugar (bottom-up rewrite),
/// 2. `pm_load_collapse`,
/// 3. `pm_split_ranges` (fresh `SplitRangesContext`) followed by `pm_flatten_range`,
/// 4. `sym()` combined with `pm_flatten_range`,
/// 5. `pm_flatten_range` combined with `pm_simplify_ranges`.
///
/// Each pass logs the resulting tree, its node count and the elapsed time at debug level.
#[tracing::instrument(skip_all)]
pub fn apply_pre_optimization(ast: Arc<morok_ir::UOp>) -> Arc<morok_ir::UOp> {
    use crate::rangeify::patterns::{movement_op_patterns, pm_syntactic_sugar};
    use crate::rangeify::transforms::SplitRangesContext;
    use crate::rewrite::graph_rewrite_bottom_up;

    static PM_EARLY_MOPS: LazyLock<crate::TypedPatternMatcher> =
        LazyLock::new(|| movement_op_patterns() + pm_syntactic_sugar());
    static PM_SYM_FLATTEN: LazyLock<crate::TypedPatternMatcher> =
        LazyLock::new(|| sym().clone() + pm_flatten_range());
    static PM_SIMPLIFY_FLATTEN: LazyLock<crate::TypedPatternMatcher> =
        LazyLock::new(|| pm_flatten_range() + pm_simplify_ranges());

    tracing::debug!(ast.initial = ast.tree(), node_count = ast.node_count(), "kernel initial");

    // 1. Movement ops and syntactic sugar, rewritten bottom-up.
    let started = std::time::Instant::now();
    let graph = graph_rewrite_bottom_up(&*PM_EARLY_MOPS, ast, &mut ());
    tracing::debug!(
        ast.pre = graph.tree(),
        node_count = graph.node_count(),
        elapsed_ms = started.elapsed().as_millis() as u64,
        "pre-opt: movement ops + syntactic sugar complete"
    );

    // 2. Load collapse.
    let started = std::time::Instant::now();
    let graph = graph_rewrite(pm_load_collapse(), graph, &mut ());
    tracing::debug!(
        ast.pre = graph.tree(),
        node_count = graph.node_count(),
        elapsed_ms = started.elapsed().as_millis() as u64,
        "pre-opt: load collapse complete"
    );

    // 3. Range splitting, then a flatten pass.
    let started = std::time::Instant::now();
    let mut split_ctx = SplitRangesContext::default();
    let graph = graph_rewrite(&pm_split_ranges(), graph, &mut split_ctx);
    let graph = graph_rewrite(pm_flatten_range(), graph, &mut ());
    tracing::debug!(
        ast.pre = graph.tree(),
        node_count = graph.node_count(),
        elapsed_ms = started.elapsed().as_millis() as u64,
        "pre-opt: split ranges complete"
    );

    // 4. Symbolic simplification together with range flattening.
    let started = std::time::Instant::now();
    let graph = graph_rewrite(&*PM_SYM_FLATTEN, graph, &mut ());
    tracing::debug!(
        ast.pre = graph.tree(),
        node_count = graph.node_count(),
        elapsed_ms = started.elapsed().as_millis() as u64,
        "pre-opt: symbolic + flatten complete"
    );

    // 5. Range simplification together with range flattening.
    let started = std::time::Instant::now();
    let graph = graph_rewrite(&*PM_SIMPLIFY_FLATTEN, graph, &mut ());
    tracing::debug!(
        ast.pre = graph.tree(),
        node_count = graph.node_count(),
        elapsed_ms = started.elapsed().as_millis() as u64,
        "pre-opt: simplify ranges complete"
    );

    graph
}
/// Runs the full kernel pipeline: pre-optimization, the strategy chosen by `config`, then
/// post-optimization lowering for `renderer`.
///
/// - `OptStrategy::None` skips optimization entirely.
/// - `OptStrategy::Heuristic` applies the hand-coded heuristics from `config.heuristics`.
/// - `OptStrategy::Beam { .. }` also falls back to the heuristics. This function receives no
///   `compile_and_time` callback, so it cannot search; use [`optimize_kernel_beam`] for a
///   real beam search.
pub fn optimize_kernel_with_config(
    ast: Arc<morok_ir::UOp>,
    renderer: &Renderer,
    config: &OptimizerConfig,
) -> Arc<morok_ir::UOp> {
    let prepared = apply_pre_optimization(ast);
    let scheduled = match config.strategy {
        OptStrategy::None => prepared,
        OptStrategy::Heuristic | OptStrategy::Beam { .. } => {
            optimize_heuristic(prepared, renderer, &config.heuristics)
        }
    };
    apply_post_optimization_with_renderer(scheduled, Some(renderer))
}
pub fn optimize_kernel_with_strategy(
ast: Arc<morok_ir::UOp>,
renderer: &Renderer,
strategy: OptStrategy,
) -> Arc<morok_ir::UOp> {
let config = OptimizerConfig { strategy, ..Default::default() };
optimize_kernel_with_config(ast, renderer, &config)
}
pub fn optimize_kernel_beam<F>(
ast: Arc<morok_ir::UOp>,
renderer: &Renderer,
config: &BeamConfig,
compile_and_time: F,
) -> Result<BeamResult, error::OptError>
where
F: Fn(&Scheduler) -> Option<std::time::Duration> + Sync,
{
let pre_optimized = apply_pre_optimization(ast);
let mut scheduler = Scheduler::new(pre_optimized, renderer.clone());
let _ = scheduler.convert_loop_to_global();
beam::beam_search_cached(scheduler, config, compile_and_time)
}
/// Pre-optimizes `ast` and wraps it in a `Scheduler` for `renderer`.
///
/// `convert_loop_to_global` is attempted once on the new scheduler before it is returned.
pub fn prepare_scheduler(ast: Arc<morok_ir::UOp>, renderer: &Renderer) -> Scheduler {
    let mut scheduler = Scheduler::new(apply_pre_optimization(ast), renderer.clone());
    // Result deliberately discarded. NOTE(review): presumably an error just means the
    // conversion did not apply; confirm.
    let _ = scheduler.convert_loop_to_global();
    scheduler
}
fn optimize_heuristic(ast: Arc<morok_ir::UOp>, renderer: &Renderer, config: &HeuristicsConfig) -> Arc<morok_ir::UOp> {
let mut scheduler = Scheduler::new(ast, renderer.clone());
let _ = scheduler.convert_loop_to_global(); let _ = scheduler.convert_outer_to_loop();
heuristics::hand_coded_optimizations(&mut scheduler, config);
scheduler.get_optimized_ast(None)
}