Skip to main content

morok_schedule/optimizer/
mod.rs

1//! Kernel optimization layer for morok-schedule.
2//!
3//! This module implements hardware-aware kernel optimization based on Tinygrad's approach.
4//! It provides a `Scheduler` that applies optimization primitives (OptOps) to transform
5//! kernel execution for better performance on specific backends.
6//!
7//! # Architecture
8//!
9//! The optimization process follows this flow:
10//!
11//! 1. **Initialization**: Create `Scheduler` from UOp AST + `Renderer` (backend capabilities)
12//! 2. **Initial Transform**: Convert eligible LOOP axes to GLOBAL (parallelization)
13//! 3. **Optimization**: Apply `Opt` operations via `apply_opt()`
14//!    - UPCAST: Vectorization (SIMD)
//!    - LOCAL: GPU workgroup (local) dimensions
16//!    - UNROLL: Loop unrolling for reductions
17//!    - GROUP: Two-stage reductions with synchronization
18//!    - TC: Tensor core acceleration
19//!    - PADTO, SWAP, THREAD, NOLOCALS: Layout and configuration
20//! 4. **Finalization**: Extract optimized AST with `get_optimized_ast()`
21//!
22//! # Optimization Strategies
23//!
24//! - **Hand-coded heuristics** (`heuristics` module): Fast, reasonable performance
//! - **Beam search** (`beam` module, optional): Slow; selects optimizations by compiling and timing candidate kernels
26//!
27//! # Example
28//!
29//! ```ignore
30//! use morok_schedule::optimizer::{Scheduler, Renderer, Opt, OptOps};
31//!
32//! // Create scheduler with CUDA backend
33//! let renderer = Renderer::cuda();
34//! let mut scheduler = Scheduler::new(kernel_ast, renderer);
35//!
36//! // Apply optimizations
37//! scheduler.convert_loop_to_global();
38//! scheduler.apply_opt(Opt::upcast(0, 4), true)?; // Vectorize axis 0 by 4
//! scheduler.apply_opt(Opt::local(1, 16), true)?; // Workgroup (local) size 16 on axis 1
40//!
41//! // Get optimized kernel
42//! let optimized_ast = scheduler.get_optimized_ast(None);
43//! ```
44
45pub mod beam;
46pub mod config;
47pub mod error;
48pub mod heuristics;
49pub mod kernel_info;
50pub mod opts;
51pub mod renderer;
52pub mod scheduler;
53pub mod tc;
54pub mod types;
55
56// Re-exports
57pub use beam::{BeamResult, beam_search, beam_search_cached, beam_search_with_timeout, clear_cache, replay_opts};
58pub use config::{BeamConfig, HeuristicsConfig, OptStrategy, OptimizerConfig, TcOpt as TcOptLevel, TcSelect, TcUsage};
59pub use error::OptError;
60pub use heuristics::hand_coded_optimizations;
61pub use kernel_info::KernelInfo;
62pub use opts::apply_opt;
63pub use renderer::{Renderer, TcOpt, TensorCore};
64pub use scheduler::Scheduler;
65#[cfg(test)]
66pub use scheduler::clear_kernel_name_counts;
67pub use types::{AxisType, Opt, OptArg, OptOps};
68
69use crate::devectorize::{
70    Fp8DecompCtx, bool_storage_patterns, pm_float_decomp, pm_float_decomp_store, pm_reduce, pm_render,
71    pm_wmma_accumulate,
72};
73use crate::gpudims::pm_add_gpudims;
74// pm_linearize_multi_index removed: Tinygrad keeps multi-index INDEX through the pipeline.
75// Codegen backends compute flat addresses at render time.
76use crate::rangeify::patterns::{
77    pm_add_loads, pm_comparison_negations, pm_demorgan, pm_div_to_shr, pm_erf_decomposition, pm_fdiv_to_mul,
78    pm_fma_decomposition, pm_load_collapse, pm_mod_to_and, pm_mul_to_shl, pm_neg_from_mul, pm_shl_add_to_mulacc,
79    pm_threefry_decomp, rangeify_codegen_with_kernel_ctx,
80};
81use crate::rangeify::pm_add_buffers_local_patterns;
82use crate::rangeify::transforms::{pm_flatten_range, pm_simplify_ranges, pm_split_ranges};
83use crate::rewrite::graph_rewrite;
84use crate::symbolic::patterns::{gep_pushing_patterns, sym, symbolic, symbolic_simple};
85use std::sync::{Arc, LazyLock};
86
87/// Apply optimizations to a kernel AST.
88///
89/// This is the main entry point for optimization in the tensor pipeline.
90/// Uses environment variables for configuration (see `OptimizerConfig::from_env`).
91///
92/// # Pipeline
93///
94/// 1. **Symbolic simplification** - Constant folding, identities, DCE
95/// 2. **Loop→Global conversion** - Enable GPU parallelization
96/// 3. **Hand-coded heuristics** - Vectorization, unrolling, tiling
97///
98/// # Arguments
99///
100/// * `ast` - The kernel AST (inner AST from KERNEL op)
101/// * `renderer` - Backend capabilities descriptor
102///
103/// # Returns
104///
105/// Optimized AST with transformations applied.
106///
107/// # Environment Variables
108///
109/// * `MOROK_NOOPT=1` - Disable all optimizations (for debugging)
110/// * `MOROK_BEAM=N` - Use beam search with width N (future)
111pub fn optimize_kernel(ast: Arc<morok_ir::UOp>, renderer: &Renderer) -> Arc<morok_ir::UOp> {
112    optimize_kernel_with_config(ast, renderer, &OptimizerConfig::from_env())
113}
114
115/// Apply post-optimization passes to kernel AST.
116///
117/// These passes run AFTER heuristic/beam optimization and BEFORE codegen:
118/// - pm_add_loads: Extract LOAD ops from INDEX
119/// - pre_expand: Convert Range(Unroll/Upcast) → UNROLL, expand operations
120/// - pm_add_gpudims (GPU only): Convert GLOBAL/LOCAL RANGE to SPECIAL thread indices
121/// - devectorize: Combined pass (sym + devec + load_store_folding + correct_load_store + indexing)
122/// - bool_storage_patterns: Convert bool LOAD/STORE to uint8
123///
124/// NOTE: We do NOT apply FMA decomposition (a*b+c → MulAcc). Following Tinygrad's
125/// approach, we let LLVM's optimizer fuse MUL+ADD into FMA when beneficial.
126///
127/// # Arguments
128///
129/// * `ast` - The kernel AST to optimize
130///
131/// Called by both heuristic and beam search paths for consistent behavior.
132/// For GPU pipelines, use `apply_post_optimization_with_renderer` to enable GPU dimension injection.
133#[tracing::instrument(skip_all)]
134pub fn apply_post_optimization(ast: Arc<morok_ir::UOp>) -> Arc<morok_ir::UOp> {
135    apply_post_optimization_with_renderer(ast, None)
136}
137
138/// Apply post-optimization passes with renderer context.
139///
140/// Same as `apply_post_optimization` but accepts an optional renderer for GPU-specific passes.
141/// When a renderer with GPU capabilities (has_local) is provided, `pm_add_gpudims` is applied
142/// to convert GLOBAL/LOCAL RANGE operations to SPECIAL thread indices.
143///
144/// # Arguments
145///
146/// * `ast` - The kernel AST to optimize
147/// * `renderer` - Optional renderer for GPU dimension injection
148#[tracing::instrument(skip_all)]
149pub fn apply_post_optimization_with_renderer(
150    ast: Arc<morok_ir::UOp>,
151    renderer: Option<&Renderer>,
152) -> Arc<morok_ir::UOp> {
153    // Save metadata before graph_rewrite destroys it (e.g., KernelInfo with kernel name)
154    let saved_metadata = ast.metadata_raw();
155
156    tracing::debug!(ast.initial = ast.tree(), node_count = ast.node_count(), "kernel initial");
157
158    // Tinygrad keeps multi-index INDEX through the pipeline — no linearization here.
159    // Codegen backends compute flat addresses at render time via render_linearize_multi_index.
160
161    // =========================================================================
162    // Stage 8: Post-opt symbolic + WHERE movement (Tinygrad: sym + pm_move_where_on_load)
163    // This MUST run BEFORE expander to optimize conditionals before expansion.
164    // =========================================================================
165    let t_stage = std::time::Instant::now();
166    // Tinygrad: sym + pm_move_where_on_load (pm_move_where_on_load only at this stage, not global)
167    static POST_OPT_SYM: LazyLock<crate::TypedPatternMatcher> =
168        LazyLock::new(|| sym().clone() + crate::symbolic::patterns::pm_move_where_on_load());
169    let with_symbolic = graph_rewrite(&*POST_OPT_SYM, ast, &mut ());
170    tracing::debug!(
171        ast.optimized = with_symbolic.tree(),
172        node_count = with_symbolic.node_count(),
173        elapsed_ms = t_stage.elapsed().as_millis() as u64,
174        "Stage 8: after post-opt symbolic"
175    );
176
177    // =========================================================================
178    // Stage 9: Expander (Tinygrad: sym + pm_pre_expander + pm_group_for_reduce + expander)
179    // =========================================================================
180    // UNROLL expansion: Expand UNROLL ops to vectorized operations (Tinygrad expander.py)
181    // CRITICAL: Must run BEFORE pm_reduce so that REDUCE sees its actual vectorized dtype.
182    // In Tinygrad, expander runs first, then pm_reduce sees the expanded REDUCE with vec2 dtype.
183    // This allows reduce_to_acc to create accumulators with the correct vector dtype.
184    let t_stage = std::time::Instant::now();
185    let expanded = crate::expand::pre_expand(&with_symbolic);
186    tracing::debug!(
187        ast.optimized = expanded.tree(),
188        node_count = expanded.node_count(),
189        elapsed_ms = t_stage.elapsed().as_millis() as u64,
190        "Stage 9: after pre_expand"
191    );
192
193    // =========================================================================
194    // Stage 10: Add local buffers (Tinygrad: pm_add_buffers_local + rangeify_codegen)
195    // =========================================================================
196    // Converts BUFFERIZE(Local) → DEFINE_LOCAL + STORE + LOAD for GROUP_REDUCE.
197    // Also strips leftover CONTIGUOUS and NOOP nodes.
198    // Must run AFTER expander (which creates BUFFERIZE_LOCAL) and BEFORE pm_reduce.
199    //
200    // CRITICAL: Combine pm_add_buffers_local + rangeify_codegen in a SINGLE pass
201    // (like Tinygrad) to ensure CONTIGUOUS is stripped BEFORE bufferize_to_store
202    // sees it. Otherwise CONTIGUOUS(BUFFER) becomes the STORE value directly,
203    // which fails codegen because STORE expects a value, not a buffer pointer.
204    // Helper closure: check for UNROLL(GROUP) in graph
205    let check_unroll_group = |label: &str, root: &Arc<morok_ir::UOp>| {
206        for node in root.toposort() {
207            if let morok_ir::Op::Unroll { src, unroll_axes, .. } = node.op()
208                && matches!(src.op(), morok_ir::Op::Group { .. })
209            {
210                tracing::error!(id = node.id, axes = ?unroll_axes, stage = label, "UNROLL(GROUP) found!");
211            }
212        }
213    };
214
215    let t_stage = std::time::Instant::now();
216    let with_local_buffers = {
217        let mut buf_ctx = crate::rangeify::KernelContext::new();
218        static PM_LOCAL_BUF: LazyLock<crate::TypedPatternMatcher<crate::rangeify::KernelContext>> =
219            LazyLock::new(|| pm_add_buffers_local_patterns() + rangeify_codegen_with_kernel_ctx());
220        graph_rewrite(&*PM_LOCAL_BUF, expanded, &mut buf_ctx)
221    };
222    tracing::debug!(
223        ast.optimized = with_local_buffers.tree(),
224        node_count = with_local_buffers.node_count(),
225        elapsed_ms = t_stage.elapsed().as_millis() as u64,
226        "Stage 10: after add local buffers"
227    );
228    if cfg!(debug_assertions) {
229        check_unroll_group("after_add_local_buffers", &with_local_buffers);
230    }
231
232    let t_stage = std::time::Instant::now();
233    static PM_REDUCE_COMBINED: LazyLock<crate::TypedPatternMatcher<crate::devectorize::ReduceContext>> =
234        LazyLock::new(|| pm_reduce() + pm_wmma_accumulate().with_context() + gep_pushing_patterns().with_context());
235    let mut reduce_ctx = crate::devectorize::ReduceContext::default();
236    let reduced = graph_rewrite(&*PM_REDUCE_COMBINED, with_local_buffers, &mut reduce_ctx);
237    tracing::debug!(
238        ast.optimized = reduced.tree(),
239        node_count = reduced.node_count(),
240        elapsed_ms = t_stage.elapsed().as_millis() as u64,
241        "after pm_reduce"
242    );
243    if cfg!(debug_assertions) {
244        check_unroll_group("after_pm_reduce", &reduced);
245    }
246
247    let t_stage = std::time::Instant::now();
248    let with_gpudims = if let Some(ren) = renderer {
249        if ren.has_local { graph_rewrite(&pm_add_gpudims(), reduced, &mut ren.clone()) } else { reduced }
250    } else {
251        reduced
252    };
253    tracing::debug!(
254        ast.optimized = with_gpudims.tree(),
255        node_count = with_gpudims.node_count(),
256        elapsed_ms = t_stage.elapsed().as_millis() as u64,
257        "after pm_add_gpudims"
258    );
259    if cfg!(debug_assertions) {
260        check_unroll_group("after_pm_add_gpudims", &with_gpudims);
261    }
262
263    let t_stage = std::time::Instant::now();
264    let with_loads = graph_rewrite(pm_add_loads(), with_gpudims, &mut ());
265    tracing::debug!(
266        ast.optimized = with_loads.tree(),
267        node_count = with_loads.node_count(),
268        elapsed_ms = t_stage.elapsed().as_millis() as u64,
269        "after pm_add_loads"
270    );
271    if cfg!(debug_assertions) {
272        check_unroll_group("after_pm_add_loads", &with_loads);
273        // Also check for any UNROLL or CONTRACT
274        for node in with_loads.toposort() {
275            if let morok_ir::Op::Unroll { src, unroll_axes, .. } = node.op() {
276                tracing::error!(
277                    id = node.id,
278                    src_op = src.op().as_ref(),
279                    axes = ?unroll_axes,
280                    "BEFORE devectorize: found UNROLL!"
281                );
282            }
283            if let morok_ir::Op::Contract { src, upcast_ranges, .. } = node.op() {
284                tracing::error!(
285                    id = node.id,
286                    src_op = src.op().as_ref(),
287                    axes = ?upcast_ranges,
288                    "BEFORE devectorize: found CONTRACT!"
289                );
290            }
291        }
292    }
293
294    // ALU devectorization happens inside devectorize() Phase 1, alongside expand_index
295    // and full symbolic (including gep_pushing). This matches Tinygrad's structure where
296    // no_vectorized_alu runs in the same pass as load_store_folding (step 14).
297    // Previously, an isolated pass here combined no_vectorized_alu + gep_pushing without
298    // load/store folding, causing graph explosion on wide VECTORIZE nodes (VECTORIZE(135)).
299    // Tinygrad Stage 14: devectorize — single combined pass handles ALL devectorization
300    // including bool ALU (via no_vectorized_alu). No separate pm_bool_devectorize or
301    // pm_reduce_devectorize passes — matching Tinygrad's pipeline exactly.
302    let t_stage = std::time::Instant::now();
303    let devectorized = crate::devectorize::devectorize(&with_loads);
304    tracing::debug!(
305        ast.optimized = devectorized.tree(),
306        node_count = devectorized.node_count(),
307        elapsed_ms = t_stage.elapsed().as_millis() as u64,
308        "after devectorize"
309    );
310    check_unroll_group("after_devectorize", &devectorized);
311
312    // Tinygrad Stage 15: pm_lower_index_dtype + load_store_indexing + gep_pushing
313    let t_stage = std::time::Instant::now();
314    static PM_LOWER_COMBINED: LazyLock<crate::TypedPatternMatcher> = LazyLock::new(|| {
315        crate::symbolic::pm_lower_index_dtype()
316            + crate::devectorize::load_store_indexing_patterns()
317            + gep_pushing_patterns()
318    });
319    let with_lowered_idx = graph_rewrite(&*PM_LOWER_COMBINED, devectorized, &mut ());
320    tracing::debug!(
321        ast.optimized = with_lowered_idx.tree(),
322        node_count = with_lowered_idx.node_count(),
323        elapsed_ms = t_stage.elapsed().as_millis() as u64,
324        "after pm_lower_index_dtype"
325    );
326    check_unroll_group("after_pm_lower_index_dtype", &with_lowered_idx);
327
328    // Tinygrad: symbolic (step 16) — full symbolic (includes gep_pushing, div_and_mod, etc.)
329    let t_stage = std::time::Instant::now();
330    static POST_INDEX_SYM: LazyLock<crate::TypedPatternMatcher> = LazyLock::new(|| symbolic().clone());
331    let with_lowered_idx = graph_rewrite(&*POST_INDEX_SYM, with_lowered_idx, &mut ());
332    tracing::debug!(
333        ast.optimized = with_lowered_idx.tree(),
334        node_count = with_lowered_idx.node_count(),
335        elapsed_ms = t_stage.elapsed().as_millis() as u64,
336        "after post-index symbolic"
337    );
338
339    // =========================================================================
340    // Stage 18-19: Decompositions + Render (Tinygrad: pm_decomp + pm_render in one pass)
341    // =========================================================================
342    let t_stage = std::time::Instant::now();
343    static PM_FINAL: LazyLock<crate::TypedPatternMatcher> =
344        LazyLock::new(|| symbolic_simple() + get_late_rewrite_patterns() + pm_render());
345    let rendered = graph_rewrite(&*PM_FINAL, with_lowered_idx, &mut ());
346    tracing::debug!(
347        ast.optimized = rendered.tree(),
348        node_count = rendered.node_count(),
349        elapsed_ms = t_stage.elapsed().as_millis() as u64,
350        "Stage 18-19: after pm_decomp + pm_render"
351    );
352
353    // Merge sibling ENDs that share the same reduce ranges.
354    // pm_decomp+pm_render can create new sibling ENDs (e.g. by rewriting computations
355    // inside an END differently per vector lane). merge_reduce_ends ran earlier in
356    // pm_reduce but only caught ENDs that existed at that point.
357    let t_merge = std::time::Instant::now();
358    let rendered = crate::devectorize::merge_sibling_ends(&rendered);
359    tracing::debug!(
360        ast.optimized = rendered.tree(),
361        node_count = rendered.node_count(),
362        elapsed_ms = t_merge.elapsed().as_millis() as u64,
363        "after merge_sibling_ends"
364    );
365
366    // FP8 float decomposition: promote FP8 computation to Float16 via bitwise conversion.
367    // Uses graph_rewrite_with_bpm: STORE pattern in bpm (sees ORIGINAL children to detect
368    // FP8 buffer ptrs), all other patterns in pm (sees OPTIMIZED children).
369    // Run once per FP8 type. Tinygrad: codegen/__init__.py:97-99
370    let t_stage = std::time::Instant::now();
371    let fp8_pm = pm_float_decomp();
372    let fp8_bpm = pm_float_decomp_store();
373    let mut fp8_decomposed = rendered;
374    for (fr, to) in [
375        (morok_dtype::ScalarDType::FP8E5M2, morok_dtype::ScalarDType::Float16),
376        (morok_dtype::ScalarDType::FP8E4M3, morok_dtype::ScalarDType::Float16),
377    ] {
378        let mut ctx = Fp8DecompCtx { from: fr, to };
379        fp8_decomposed = morok_ir::rewrite::graph_rewrite_with_bpm(&fp8_pm, &fp8_bpm, fp8_decomposed, &mut ctx);
380    }
381    tracing::debug!(
382        ast.optimized = fp8_decomposed.tree(),
383        node_count = fp8_decomposed.node_count(),
384        elapsed_ms = t_stage.elapsed().as_millis() as u64,
385        "after pm_float_decomp"
386    );
387
388    let t_stage = std::time::Instant::now();
389    let bs = graph_rewrite(bool_storage_patterns(), fp8_decomposed, &mut ());
390    tracing::debug!(
391        ast.optimized = bs.tree(),
392        node_count = bs.node_count(),
393        elapsed_ms = t_stage.elapsed().as_millis() as u64,
394        "after bool_storage_pattern"
395    );
396
397    // Re-attach metadata (e.g., KernelInfo) that was lost during graph rewrites
398    match saved_metadata {
399        Some(meta) => bs.with_metadata_raw(meta),
400        None => bs,
401    }
402}
403
404/// Late rewrite patterns for algebraic decompositions.
405///
406/// Based on Tinygrad's `get_late_rewrite_patterns` (decompositions.py:438-480).
407///
408/// Returns patterns for:
409/// - MULACC (FMA): `a*b+c → MulAcc(a,b,c)` for float types
410/// - MOD → AND: `x % 2^n → x & (2^n-1)` for power-of-two modulus
411/// - MUL → SHL: `x * 2^n → x << n` for power-of-two multiplier
412/// - NEG from MUL: `x * -1 → NEG(x)`
413/// - Fast integer division (magic number multiplication)
414fn get_late_rewrite_patterns() -> &'static crate::TypedPatternMatcher {
415    // All current backends support MAX and SQRT natively (LLVM, CUDA, Metal).
416    // When we add backends that lack support, this should take a capability set
417    // (like Tinygrad's `ops: tuple[Ops, ...]`) and conditionally include patterns.
418    static CACHED: LazyLock<crate::TypedPatternMatcher> = LazyLock::new(|| {
419        pm_fma_decomposition()
420            + pm_erf_decomposition()
421            + pm_mod_to_and()
422            + pm_mul_to_shl()
423            + pm_div_to_shr()
424            + pm_fdiv_to_mul()
425            + pm_neg_from_mul()
426            + pm_demorgan()
427            + pm_shl_add_to_mulacc()
428            + pm_threefry_decomp()
429            + pm_comparison_negations()
430            + crate::symbolic::fast_division_patterns()
431            + pm_mod_to_idiv()
432    });
433    &CACHED
434}
435
436/// MOD → IDIV decomposition (Tinygrad decompositions.py:457).
437///
438/// `x % d → x - d*(x//d)` for non-power-of-2 constant divisors.
439/// Runs AFTER fast_division_patterns so the resulting IDIV gets decomposed
440/// to magic-number multiplication. Without this, standalone MOD nodes
441/// for non-power-of-2 divisors survive to codegen unlowered.
442fn pm_mod_to_idiv() -> &'static crate::TypedPatternMatcher {
443    crate::cached_patterns! {
444        Mod(x, d @const(d_val))
445            if x.dtype().is_int()
446            && matches!(d_val.try_int(), Some(v) if v > 1 && !((v as u64).is_power_of_two()))
447            => {
448                // x % d → x - d * (x // d)
449                let div = x.idiv(d);
450                let mul = d.try_mul(&div).ok()?;
451                x.try_sub(&mul).ok()
452            },
453    }
454}
455
456/// Apply per-kernel pre-optimization passes.
457///
458/// These stages run BEFORE heuristic/beam optimization, per-kernel
459/// (Tinygrad: inside `full_rewrite_to_sink()`, codegen/__init__.py:28-51).
460///
461/// Stages:
462/// 0. Movement ops + syntactic sugar (`pm_mops + pm_syntactic_sugar`, bottom-up)
463/// 1. Load collapse (`pm_load_collapse`)
464/// 2. Split ranges + flatten (`pm_split_ranges + pm_flatten_range`)
465/// 3. Symbolic + flatten (`sym + pm_flatten_range`)
466/// 4. Simplify ranges (`pm_simplify_ranges`)
467///
468/// Called by both heuristic and beam search paths.
469#[tracing::instrument(skip_all)]
470pub fn apply_pre_optimization(ast: Arc<morok_ir::UOp>) -> Arc<morok_ir::UOp> {
471    tracing::debug!(ast.initial = ast.tree(), node_count = ast.node_count(), "kernel initial");
472
473    use crate::rangeify::transforms::SplitRangesContext;
474
475    let t_stage = std::time::Instant::now();
476    use crate::rangeify::patterns::{movement_op_patterns, pm_syntactic_sugar};
477    use crate::rewrite::graph_rewrite_bottom_up;
478    static PM_EARLY_MOPS: LazyLock<crate::TypedPatternMatcher> =
479        LazyLock::new(|| movement_op_patterns() + pm_syntactic_sugar());
480    let mut sink = graph_rewrite_bottom_up(&*PM_EARLY_MOPS, ast, &mut ());
481    tracing::debug!(
482        ast.pre = sink.tree(),
483        node_count = sink.node_count(),
484        elapsed_ms = t_stage.elapsed().as_millis() as u64,
485        "pre-opt: movement ops + syntactic sugar complete"
486    );
487
488    let t_stage = std::time::Instant::now();
489    sink = graph_rewrite(pm_load_collapse(), sink, &mut ());
490    tracing::debug!(
491        ast.pre = sink.tree(),
492        node_count = sink.node_count(),
493        elapsed_ms = t_stage.elapsed().as_millis() as u64,
494        "pre-opt: load collapse complete"
495    );
496
497    let t_stage = std::time::Instant::now();
498    let mut split_ctx = SplitRangesContext::default();
499    sink = graph_rewrite(&pm_split_ranges(), sink, &mut split_ctx);
500    sink = graph_rewrite(pm_flatten_range(), sink, &mut ());
501    tracing::debug!(
502        ast.pre = sink.tree(),
503        node_count = sink.node_count(),
504        elapsed_ms = t_stage.elapsed().as_millis() as u64,
505        "pre-opt: split ranges complete"
506    );
507
508    let t_stage = std::time::Instant::now();
509    // Tinygrad: sym + pm_flatten_range (pre-opt uses full sym tier)
510    static PM_SYM_FLATTEN: LazyLock<crate::TypedPatternMatcher> = LazyLock::new(|| sym().clone() + pm_flatten_range());
511    sink = graph_rewrite(&*PM_SYM_FLATTEN, sink, &mut ());
512    tracing::debug!(
513        ast.pre = sink.tree(),
514        node_count = sink.node_count(),
515        elapsed_ms = t_stage.elapsed().as_millis() as u64,
516        "pre-opt: symbolic + flatten complete"
517    );
518
519    let t_stage = std::time::Instant::now();
520    static PM_SIMPLIFY_FLATTEN: LazyLock<crate::TypedPatternMatcher> =
521        LazyLock::new(|| pm_flatten_range() + pm_simplify_ranges());
522    sink = graph_rewrite(&*PM_SIMPLIFY_FLATTEN, sink, &mut ());
523    tracing::debug!(
524        ast.pre = sink.tree(),
525        node_count = sink.node_count(),
526        elapsed_ms = t_stage.elapsed().as_millis() as u64,
527        "pre-opt: simplify ranges complete"
528    );
529
530    sink
531}
532
533/// Apply optimizations with explicit configuration.
534///
535/// Use this when you need explicit control over the optimization settings.
536///
537/// Note: For beam search strategy, this falls back to heuristics because
538/// beam search requires a `compile_and_time` function from the runtime.
539/// Use `optimize_kernel_beam()` for actual beam search optimization.
540pub fn optimize_kernel_with_config(
541    ast: Arc<morok_ir::UOp>,
542    renderer: &Renderer,
543    config: &OptimizerConfig,
544) -> Arc<morok_ir::UOp> {
545    // Pre-optimization: per-kernel stages (Tinygrad: full_rewrite_to_sink)
546    let pre_optimized = apply_pre_optimization(ast);
547
548    let optimized = match config.strategy {
549        OptStrategy::None => pre_optimized, // No heuristic optimization, but post-optimization still needed
550        OptStrategy::Heuristic => optimize_heuristic(pre_optimized, renderer, &config.heuristics),
551        OptStrategy::Beam { .. } => {
552            // Beam search requires a compile_and_time function.
553            // Use optimize_kernel_beam() for actual beam search.
554            // Fall back to heuristics for the simple API.
555            optimize_heuristic(pre_optimized, renderer, &config.heuristics)
556        }
557    };
558
559    // apply_post_optimization contains correctness transforms (pm_add_loads wraps INDEX
560    // with LOAD for arithmetic ops) and must run even when optimizations are disabled.
561    // Pass the renderer to enable GPU dimension injection for GPU backends.
562
563    apply_post_optimization_with_renderer(optimized, Some(renderer))
564}
565
566/// Apply optimizations with explicit strategy selection (legacy API).
567///
568/// Prefer `optimize_kernel_with_config` for new code.
569pub fn optimize_kernel_with_strategy(
570    ast: Arc<morok_ir::UOp>,
571    renderer: &Renderer,
572    strategy: OptStrategy,
573) -> Arc<morok_ir::UOp> {
574    let config = OptimizerConfig { strategy, ..Default::default() };
575    optimize_kernel_with_config(ast, renderer, &config)
576}
577
578/// Apply beam search optimization with custom timing function.
579///
580/// This is the primary entry point for beam search auto-tuning. It requires
581/// a `compile_and_time` function that compiles a scheduler state and returns
582/// its execution timing.
583///
584/// # Arguments
585///
586/// * `ast` - The kernel AST to optimize
587/// * `renderer` - Backend capabilities descriptor
588/// * `config` - Beam search configuration
589/// * `compile_and_time` - Function to compile and time a scheduler
590///
591/// # Returns
592///
593/// Result containing `BeamResult` with optimized scheduler and metrics.
594///
595/// # Example
596///
597/// ```ignore
598/// use morok_schedule::optimizer::{optimize_kernel_beam, BeamConfig, Renderer};
599/// use morok_runtime::{BenchmarkConfig, benchmark_kernel};
600///
601/// let config = BeamConfig::from_env();
602/// let renderer = Renderer::cpu();
603///
604/// let compile_and_time = |scheduler: &Scheduler| -> Option<Duration> {
605///     let ast = scheduler.get_optimized_ast(None);
606///     let kernel = compile_kernel(&ast)?;
607///     let result = benchmark_kernel(&kernel, &buffers, &vars, &bench_config).ok()?;
608///     Some(result.min)
609/// };
610///
611/// let result = optimize_kernel_beam(ast, &renderer, &config, compile_and_time)?;
612/// let optimized_ast = result.scheduler.get_optimized_ast(None);
613/// ```
614pub fn optimize_kernel_beam<F>(
615    ast: Arc<morok_ir::UOp>,
616    renderer: &Renderer,
617    config: &BeamConfig,
618    compile_and_time: F,
619) -> Result<BeamResult, error::OptError>
620where
621    F: Fn(&Scheduler) -> Option<std::time::Duration> + Sync,
622{
623    // Step 0: Per-kernel pre-optimization (Tinygrad: full_rewrite_to_sink)
624    let pre_optimized = apply_pre_optimization(ast);
625
626    // Step 1: Create scheduler (AST already simplified by apply_pre_optimization Stage 3)
627    let mut scheduler = Scheduler::new(pre_optimized, renderer.clone());
628
629    // Step 2: Convert loops to global (for GPU parallelization)
630    let _ = scheduler.convert_loop_to_global();
631
632    // Step 4: Run beam search (with caching)
633    beam::beam_search_cached(scheduler, config, compile_and_time)
634}
635
636/// Create a scheduler ready for optimization without applying any opts.
637///
638/// This is useful when you want to manually control the optimization process
639/// or use beam search with custom logic.
640///
641/// # Arguments
642///
643/// * `ast` - The kernel AST
644/// * `renderer` - Backend capabilities descriptor
645///
646/// # Returns
647///
648/// A `Scheduler` with loops converted to globals (if applicable).
649pub fn prepare_scheduler(ast: Arc<morok_ir::UOp>, renderer: &Renderer) -> Scheduler {
650    let pre_optimized = apply_pre_optimization(ast);
651    let mut scheduler = Scheduler::new(pre_optimized, renderer.clone());
652    let _ = scheduler.convert_loop_to_global(); // GPU: LOOP→GLOBAL
653    // Note: Don't apply threading here - let beam search explore THREAD actions naturally.
654    // Heuristics apply threading via hand_coded_optimizations() with config.thread_count.
655    scheduler
656}
657
658/// Apply heuristic-based optimizations.
659fn optimize_heuristic(ast: Arc<morok_ir::UOp>, renderer: &Renderer, config: &HeuristicsConfig) -> Arc<morok_ir::UOp> {
660    // Step 1: Create scheduler (AST already simplified by apply_pre_optimization Stage 3)
661    let mut scheduler = Scheduler::new(ast, renderer.clone());
662
663    // Step 3: Convert axes for parallelization/vectorization
664    let _ = scheduler.convert_loop_to_global(); // GPU: LOOP→GLOBAL
665    let _ = scheduler.convert_outer_to_loop(); // CPU: OUTER→LOOP (enables UPCAST)
666
667    // Step 4: Apply hand-coded heuristics with config
668    heuristics::hand_coded_optimizations(&mut scheduler, config);
669
670    // Step 5: Extract optimized AST
671    scheduler.get_optimized_ast(None)
672}