morok_schedule/optimizer/mod.rs
1//! Kernel optimization layer for morok-schedule.
2//!
3//! This module implements hardware-aware kernel optimization based on Tinygrad's approach.
4//! It provides a `Scheduler` that applies optimization primitives (OptOps) to transform
5//! kernel execution for better performance on specific backends.
6//!
7//! # Architecture
8//!
9//! The optimization process follows this flow:
10//!
11//! 1. **Initialization**: Create `Scheduler` from UOp AST + `Renderer` (backend capabilities)
12//! 2. **Initial Transform**: Convert eligible LOOP axes to GLOBAL (parallelization)
13//! 3. **Optimization**: Apply `Opt` operations via `apply_opt()`
14//! - UPCAST: Vectorization (SIMD)
15//! - LOCAL: GPU workgroup dimensions (shared memory)
16//! - UNROLL: Loop unrolling for reductions
17//! - GROUP: Two-stage reductions with synchronization
18//! - TC: Tensor core acceleration
19//! - PADTO, SWAP, THREAD, NOLOCALS: Layout and configuration
20//! 4. **Finalization**: Extract optimized AST with `get_optimized_ast()`
21//!
22//! # Optimization Strategies
23//!
24//! - **Hand-coded heuristics** (`heuristics` module): Fast, reasonable performance
25//! - **Beam search** (`beam` module, optional): Slow, ML-quality performance
26//!
27//! # Example
28//!
29//! ```ignore
30//! use morok_schedule::optimizer::{Scheduler, Renderer, Opt, OptOps};
31//!
32//! // Create scheduler with CUDA backend
33//! let renderer = Renderer::cuda();
34//! let mut scheduler = Scheduler::new(kernel_ast, renderer);
35//!
36//! // Apply optimizations
37//! scheduler.convert_loop_to_global();
38//! scheduler.apply_opt(Opt::upcast(0, 4), true)?; // Vectorize axis 0 by 4
39//! scheduler.apply_opt(Opt::local(1, 16), true)?; // Local memory for axis 1
40//!
41//! // Get optimized kernel
42//! let optimized_ast = scheduler.get_optimized_ast(None);
43//! ```
44
45pub mod beam;
46pub mod config;
47pub mod error;
48pub mod heuristics;
49pub mod kernel_info;
50pub mod opts;
51pub mod renderer;
52pub mod scheduler;
53pub mod tc;
54pub mod types;
55
56// Re-exports
57pub use beam::{BeamResult, beam_search, beam_search_cached, beam_search_with_timeout, clear_cache, replay_opts};
58pub use config::{BeamConfig, HeuristicsConfig, OptStrategy, OptimizerConfig, TcOpt as TcOptLevel, TcSelect, TcUsage};
59pub use error::OptError;
60pub use heuristics::hand_coded_optimizations;
61pub use kernel_info::KernelInfo;
62pub use opts::apply_opt;
63pub use renderer::{Renderer, TcOpt, TensorCore};
64pub use scheduler::Scheduler;
65#[cfg(test)]
66pub use scheduler::clear_kernel_name_counts;
67pub use types::{AxisType, Opt, OptArg, OptOps};
68
69use crate::devectorize::{
70 Fp8DecompCtx, bool_storage_patterns, pm_float_decomp, pm_float_decomp_store, pm_reduce, pm_render,
71 pm_wmma_accumulate,
72};
73use crate::gpudims::pm_add_gpudims;
74// pm_linearize_multi_index removed: Tinygrad keeps multi-index INDEX through the pipeline.
75// Codegen backends compute flat addresses at render time.
76use crate::rangeify::patterns::{
77 pm_add_loads, pm_comparison_negations, pm_demorgan, pm_div_to_shr, pm_erf_decomposition, pm_fdiv_to_mul,
78 pm_fma_decomposition, pm_load_collapse, pm_mod_to_and, pm_mul_to_shl, pm_neg_from_mul, pm_shl_add_to_mulacc,
79 pm_threefry_decomp, rangeify_codegen_with_kernel_ctx,
80};
81use crate::rangeify::pm_add_buffers_local_patterns;
82use crate::rangeify::transforms::{pm_flatten_range, pm_simplify_ranges, pm_split_ranges};
83use crate::rewrite::graph_rewrite;
84use crate::symbolic::patterns::{gep_pushing_patterns, sym, symbolic, symbolic_simple};
85use std::sync::{Arc, LazyLock};
86
87/// Apply optimizations to a kernel AST.
88///
89/// This is the main entry point for optimization in the tensor pipeline.
90/// Uses environment variables for configuration (see `OptimizerConfig::from_env`).
91///
92/// # Pipeline
93///
94/// 1. **Symbolic simplification** - Constant folding, identities, DCE
95/// 2. **Loop→Global conversion** - Enable GPU parallelization
96/// 3. **Hand-coded heuristics** - Vectorization, unrolling, tiling
97///
98/// # Arguments
99///
100/// * `ast` - The kernel AST (inner AST from KERNEL op)
101/// * `renderer` - Backend capabilities descriptor
102///
103/// # Returns
104///
105/// Optimized AST with transformations applied.
106///
107/// # Environment Variables
108///
109/// * `MOROK_NOOPT=1` - Disable all optimizations (for debugging)
110/// * `MOROK_BEAM=N` - Use beam search with width N (future)
111pub fn optimize_kernel(ast: Arc<morok_ir::UOp>, renderer: &Renderer) -> Arc<morok_ir::UOp> {
112 optimize_kernel_with_config(ast, renderer, &OptimizerConfig::from_env())
113}
114
115/// Apply post-optimization passes to kernel AST.
116///
117/// These passes run AFTER heuristic/beam optimization and BEFORE codegen:
118/// - pm_add_loads: Extract LOAD ops from INDEX
119/// - pre_expand: Convert Range(Unroll/Upcast) → UNROLL, expand operations
120/// - pm_add_gpudims (GPU only): Convert GLOBAL/LOCAL RANGE to SPECIAL thread indices
121/// - devectorize: Combined pass (sym + devec + load_store_folding + correct_load_store + indexing)
122/// - bool_storage_patterns: Convert bool LOAD/STORE to uint8
123///
124/// NOTE: We do NOT apply FMA decomposition (a*b+c → MulAcc). Following Tinygrad's
125/// approach, we let LLVM's optimizer fuse MUL+ADD into FMA when beneficial.
126///
127/// # Arguments
128///
129/// * `ast` - The kernel AST to optimize
130///
131/// Called by both heuristic and beam search paths for consistent behavior.
132/// For GPU pipelines, use `apply_post_optimization_with_renderer` to enable GPU dimension injection.
133#[tracing::instrument(skip_all)]
134pub fn apply_post_optimization(ast: Arc<morok_ir::UOp>) -> Arc<morok_ir::UOp> {
135 apply_post_optimization_with_renderer(ast, None)
136}
137
138/// Apply post-optimization passes with renderer context.
139///
140/// Same as `apply_post_optimization` but accepts an optional renderer for GPU-specific passes.
141/// When a renderer with GPU capabilities (has_local) is provided, `pm_add_gpudims` is applied
142/// to convert GLOBAL/LOCAL RANGE operations to SPECIAL thread indices.
143///
144/// # Arguments
145///
146/// * `ast` - The kernel AST to optimize
147/// * `renderer` - Optional renderer for GPU dimension injection
148#[tracing::instrument(skip_all)]
149pub fn apply_post_optimization_with_renderer(
150 ast: Arc<morok_ir::UOp>,
151 renderer: Option<&Renderer>,
152) -> Arc<morok_ir::UOp> {
153 // Save metadata before graph_rewrite destroys it (e.g., KernelInfo with kernel name)
154 let saved_metadata = ast.metadata_raw();
155
156 tracing::debug!(ast.initial = ast.tree(), node_count = ast.node_count(), "kernel initial");
157
158 // Tinygrad keeps multi-index INDEX through the pipeline — no linearization here.
159 // Codegen backends compute flat addresses at render time via render_linearize_multi_index.
160
161 // =========================================================================
162 // Stage 8: Post-opt symbolic + WHERE movement (Tinygrad: sym + pm_move_where_on_load)
163 // This MUST run BEFORE expander to optimize conditionals before expansion.
164 // =========================================================================
165 let t_stage = std::time::Instant::now();
166 // Tinygrad: sym + pm_move_where_on_load (pm_move_where_on_load only at this stage, not global)
167 static POST_OPT_SYM: LazyLock<crate::TypedPatternMatcher> =
168 LazyLock::new(|| sym().clone() + crate::symbolic::patterns::pm_move_where_on_load());
169 let with_symbolic = graph_rewrite(&*POST_OPT_SYM, ast, &mut ());
170 tracing::debug!(
171 ast.optimized = with_symbolic.tree(),
172 node_count = with_symbolic.node_count(),
173 elapsed_ms = t_stage.elapsed().as_millis() as u64,
174 "Stage 8: after post-opt symbolic"
175 );
176
177 // =========================================================================
178 // Stage 9: Expander (Tinygrad: sym + pm_pre_expander + pm_group_for_reduce + expander)
179 // =========================================================================
180 // UNROLL expansion: Expand UNROLL ops to vectorized operations (Tinygrad expander.py)
181 // CRITICAL: Must run BEFORE pm_reduce so that REDUCE sees its actual vectorized dtype.
182 // In Tinygrad, expander runs first, then pm_reduce sees the expanded REDUCE with vec2 dtype.
183 // This allows reduce_to_acc to create accumulators with the correct vector dtype.
184 let t_stage = std::time::Instant::now();
185 let expanded = crate::expand::pre_expand(&with_symbolic);
186 tracing::debug!(
187 ast.optimized = expanded.tree(),
188 node_count = expanded.node_count(),
189 elapsed_ms = t_stage.elapsed().as_millis() as u64,
190 "Stage 9: after pre_expand"
191 );
192
193 // =========================================================================
194 // Stage 10: Add local buffers (Tinygrad: pm_add_buffers_local + rangeify_codegen)
195 // =========================================================================
196 // Converts BUFFERIZE(Local) → DEFINE_LOCAL + STORE + LOAD for GROUP_REDUCE.
197 // Also strips leftover CONTIGUOUS and NOOP nodes.
198 // Must run AFTER expander (which creates BUFFERIZE_LOCAL) and BEFORE pm_reduce.
199 //
200 // CRITICAL: Combine pm_add_buffers_local + rangeify_codegen in a SINGLE pass
201 // (like Tinygrad) to ensure CONTIGUOUS is stripped BEFORE bufferize_to_store
202 // sees it. Otherwise CONTIGUOUS(BUFFER) becomes the STORE value directly,
203 // which fails codegen because STORE expects a value, not a buffer pointer.
204 // Helper closure: check for UNROLL(GROUP) in graph
205 let check_unroll_group = |label: &str, root: &Arc<morok_ir::UOp>| {
206 for node in root.toposort() {
207 if let morok_ir::Op::Unroll { src, unroll_axes, .. } = node.op()
208 && matches!(src.op(), morok_ir::Op::Group { .. })
209 {
210 tracing::error!(id = node.id, axes = ?unroll_axes, stage = label, "UNROLL(GROUP) found!");
211 }
212 }
213 };
214
215 let t_stage = std::time::Instant::now();
216 let with_local_buffers = {
217 let mut buf_ctx = crate::rangeify::KernelContext::new();
218 static PM_LOCAL_BUF: LazyLock<crate::TypedPatternMatcher<crate::rangeify::KernelContext>> =
219 LazyLock::new(|| pm_add_buffers_local_patterns() + rangeify_codegen_with_kernel_ctx());
220 graph_rewrite(&*PM_LOCAL_BUF, expanded, &mut buf_ctx)
221 };
222 tracing::debug!(
223 ast.optimized = with_local_buffers.tree(),
224 node_count = with_local_buffers.node_count(),
225 elapsed_ms = t_stage.elapsed().as_millis() as u64,
226 "Stage 10: after add local buffers"
227 );
228 if cfg!(debug_assertions) {
229 check_unroll_group("after_add_local_buffers", &with_local_buffers);
230 }
231
232 let t_stage = std::time::Instant::now();
233 static PM_REDUCE_COMBINED: LazyLock<crate::TypedPatternMatcher<crate::devectorize::ReduceContext>> =
234 LazyLock::new(|| pm_reduce() + pm_wmma_accumulate().with_context() + gep_pushing_patterns().with_context());
235 let mut reduce_ctx = crate::devectorize::ReduceContext::default();
236 let reduced = graph_rewrite(&*PM_REDUCE_COMBINED, with_local_buffers, &mut reduce_ctx);
237 tracing::debug!(
238 ast.optimized = reduced.tree(),
239 node_count = reduced.node_count(),
240 elapsed_ms = t_stage.elapsed().as_millis() as u64,
241 "after pm_reduce"
242 );
243 if cfg!(debug_assertions) {
244 check_unroll_group("after_pm_reduce", &reduced);
245 }
246
247 let t_stage = std::time::Instant::now();
248 let with_gpudims = if let Some(ren) = renderer {
249 if ren.has_local { graph_rewrite(&pm_add_gpudims(), reduced, &mut ren.clone()) } else { reduced }
250 } else {
251 reduced
252 };
253 tracing::debug!(
254 ast.optimized = with_gpudims.tree(),
255 node_count = with_gpudims.node_count(),
256 elapsed_ms = t_stage.elapsed().as_millis() as u64,
257 "after pm_add_gpudims"
258 );
259 if cfg!(debug_assertions) {
260 check_unroll_group("after_pm_add_gpudims", &with_gpudims);
261 }
262
263 let t_stage = std::time::Instant::now();
264 let with_loads = graph_rewrite(pm_add_loads(), with_gpudims, &mut ());
265 tracing::debug!(
266 ast.optimized = with_loads.tree(),
267 node_count = with_loads.node_count(),
268 elapsed_ms = t_stage.elapsed().as_millis() as u64,
269 "after pm_add_loads"
270 );
271 if cfg!(debug_assertions) {
272 check_unroll_group("after_pm_add_loads", &with_loads);
273 // Also check for any UNROLL or CONTRACT
274 for node in with_loads.toposort() {
275 if let morok_ir::Op::Unroll { src, unroll_axes, .. } = node.op() {
276 tracing::error!(
277 id = node.id,
278 src_op = src.op().as_ref(),
279 axes = ?unroll_axes,
280 "BEFORE devectorize: found UNROLL!"
281 );
282 }
283 if let morok_ir::Op::Contract { src, upcast_ranges, .. } = node.op() {
284 tracing::error!(
285 id = node.id,
286 src_op = src.op().as_ref(),
287 axes = ?upcast_ranges,
288 "BEFORE devectorize: found CONTRACT!"
289 );
290 }
291 }
292 }
293
294 // ALU devectorization happens inside devectorize() Phase 1, alongside expand_index
295 // and full symbolic (including gep_pushing). This matches Tinygrad's structure where
296 // no_vectorized_alu runs in the same pass as load_store_folding (step 14).
297 // Previously, an isolated pass here combined no_vectorized_alu + gep_pushing without
298 // load/store folding, causing graph explosion on wide VECTORIZE nodes (VECTORIZE(135)).
299 // Tinygrad Stage 14: devectorize — single combined pass handles ALL devectorization
300 // including bool ALU (via no_vectorized_alu). No separate pm_bool_devectorize or
301 // pm_reduce_devectorize passes — matching Tinygrad's pipeline exactly.
302 let t_stage = std::time::Instant::now();
303 let devectorized = crate::devectorize::devectorize(&with_loads);
304 tracing::debug!(
305 ast.optimized = devectorized.tree(),
306 node_count = devectorized.node_count(),
307 elapsed_ms = t_stage.elapsed().as_millis() as u64,
308 "after devectorize"
309 );
310 check_unroll_group("after_devectorize", &devectorized);
311
312 // Tinygrad Stage 15: pm_lower_index_dtype + load_store_indexing + gep_pushing
313 let t_stage = std::time::Instant::now();
314 static PM_LOWER_COMBINED: LazyLock<crate::TypedPatternMatcher> = LazyLock::new(|| {
315 crate::symbolic::pm_lower_index_dtype()
316 + crate::devectorize::load_store_indexing_patterns()
317 + gep_pushing_patterns()
318 });
319 let with_lowered_idx = graph_rewrite(&*PM_LOWER_COMBINED, devectorized, &mut ());
320 tracing::debug!(
321 ast.optimized = with_lowered_idx.tree(),
322 node_count = with_lowered_idx.node_count(),
323 elapsed_ms = t_stage.elapsed().as_millis() as u64,
324 "after pm_lower_index_dtype"
325 );
326 check_unroll_group("after_pm_lower_index_dtype", &with_lowered_idx);
327
328 // Tinygrad: symbolic (step 16) — full symbolic (includes gep_pushing, div_and_mod, etc.)
329 let t_stage = std::time::Instant::now();
330 static POST_INDEX_SYM: LazyLock<crate::TypedPatternMatcher> = LazyLock::new(|| symbolic().clone());
331 let with_lowered_idx = graph_rewrite(&*POST_INDEX_SYM, with_lowered_idx, &mut ());
332 tracing::debug!(
333 ast.optimized = with_lowered_idx.tree(),
334 node_count = with_lowered_idx.node_count(),
335 elapsed_ms = t_stage.elapsed().as_millis() as u64,
336 "after post-index symbolic"
337 );
338
339 // =========================================================================
340 // Stage 18-19: Decompositions + Render (Tinygrad: pm_decomp + pm_render in one pass)
341 // =========================================================================
342 let t_stage = std::time::Instant::now();
343 static PM_FINAL: LazyLock<crate::TypedPatternMatcher> =
344 LazyLock::new(|| symbolic_simple() + get_late_rewrite_patterns() + pm_render());
345 let rendered = graph_rewrite(&*PM_FINAL, with_lowered_idx, &mut ());
346 tracing::debug!(
347 ast.optimized = rendered.tree(),
348 node_count = rendered.node_count(),
349 elapsed_ms = t_stage.elapsed().as_millis() as u64,
350 "Stage 18-19: after pm_decomp + pm_render"
351 );
352
353 // Merge sibling ENDs that share the same reduce ranges.
354 // pm_decomp+pm_render can create new sibling ENDs (e.g. by rewriting computations
355 // inside an END differently per vector lane). merge_reduce_ends ran earlier in
356 // pm_reduce but only caught ENDs that existed at that point.
357 let t_merge = std::time::Instant::now();
358 let rendered = crate::devectorize::merge_sibling_ends(&rendered);
359 tracing::debug!(
360 ast.optimized = rendered.tree(),
361 node_count = rendered.node_count(),
362 elapsed_ms = t_merge.elapsed().as_millis() as u64,
363 "after merge_sibling_ends"
364 );
365
366 // FP8 float decomposition: promote FP8 computation to Float16 via bitwise conversion.
367 // Uses graph_rewrite_with_bpm: STORE pattern in bpm (sees ORIGINAL children to detect
368 // FP8 buffer ptrs), all other patterns in pm (sees OPTIMIZED children).
369 // Run once per FP8 type. Tinygrad: codegen/__init__.py:97-99
370 let t_stage = std::time::Instant::now();
371 let fp8_pm = pm_float_decomp();
372 let fp8_bpm = pm_float_decomp_store();
373 let mut fp8_decomposed = rendered;
374 for (fr, to) in [
375 (morok_dtype::ScalarDType::FP8E5M2, morok_dtype::ScalarDType::Float16),
376 (morok_dtype::ScalarDType::FP8E4M3, morok_dtype::ScalarDType::Float16),
377 ] {
378 let mut ctx = Fp8DecompCtx { from: fr, to };
379 fp8_decomposed = morok_ir::rewrite::graph_rewrite_with_bpm(&fp8_pm, &fp8_bpm, fp8_decomposed, &mut ctx);
380 }
381 tracing::debug!(
382 ast.optimized = fp8_decomposed.tree(),
383 node_count = fp8_decomposed.node_count(),
384 elapsed_ms = t_stage.elapsed().as_millis() as u64,
385 "after pm_float_decomp"
386 );
387
388 let t_stage = std::time::Instant::now();
389 let bs = graph_rewrite(bool_storage_patterns(), fp8_decomposed, &mut ());
390 tracing::debug!(
391 ast.optimized = bs.tree(),
392 node_count = bs.node_count(),
393 elapsed_ms = t_stage.elapsed().as_millis() as u64,
394 "after bool_storage_pattern"
395 );
396
397 // Re-attach metadata (e.g., KernelInfo) that was lost during graph rewrites
398 match saved_metadata {
399 Some(meta) => bs.with_metadata_raw(meta),
400 None => bs,
401 }
402}
403
404/// Late rewrite patterns for algebraic decompositions.
405///
406/// Based on Tinygrad's `get_late_rewrite_patterns` (decompositions.py:438-480).
407///
408/// Returns patterns for:
409/// - MULACC (FMA): `a*b+c → MulAcc(a,b,c)` for float types
410/// - MOD → AND: `x % 2^n → x & (2^n-1)` for power-of-two modulus
411/// - MUL → SHL: `x * 2^n → x << n` for power-of-two multiplier
412/// - NEG from MUL: `x * -1 → NEG(x)`
413/// - Fast integer division (magic number multiplication)
414fn get_late_rewrite_patterns() -> &'static crate::TypedPatternMatcher {
415 // All current backends support MAX and SQRT natively (LLVM, CUDA, Metal).
416 // When we add backends that lack support, this should take a capability set
417 // (like Tinygrad's `ops: tuple[Ops, ...]`) and conditionally include patterns.
418 static CACHED: LazyLock<crate::TypedPatternMatcher> = LazyLock::new(|| {
419 pm_fma_decomposition()
420 + pm_erf_decomposition()
421 + pm_mod_to_and()
422 + pm_mul_to_shl()
423 + pm_div_to_shr()
424 + pm_fdiv_to_mul()
425 + pm_neg_from_mul()
426 + pm_demorgan()
427 + pm_shl_add_to_mulacc()
428 + pm_threefry_decomp()
429 + pm_comparison_negations()
430 + crate::symbolic::fast_division_patterns()
431 + pm_mod_to_idiv()
432 });
433 &CACHED
434}
435
436/// MOD → IDIV decomposition (Tinygrad decompositions.py:457).
437///
438/// `x % d → x - d*(x//d)` for non-power-of-2 constant divisors.
439/// Runs AFTER fast_division_patterns so the resulting IDIV gets decomposed
440/// to magic-number multiplication. Without this, standalone MOD nodes
441/// for non-power-of-2 divisors survive to codegen unlowered.
442fn pm_mod_to_idiv() -> &'static crate::TypedPatternMatcher {
443 crate::cached_patterns! {
444 Mod(x, d @const(d_val))
445 if x.dtype().is_int()
446 && matches!(d_val.try_int(), Some(v) if v > 1 && !((v as u64).is_power_of_two()))
447 => {
448 // x % d → x - d * (x // d)
449 let div = x.idiv(d);
450 let mul = d.try_mul(&div).ok()?;
451 x.try_sub(&mul).ok()
452 },
453 }
454}
455
456/// Apply per-kernel pre-optimization passes.
457///
458/// These stages run BEFORE heuristic/beam optimization, per-kernel
459/// (Tinygrad: inside `full_rewrite_to_sink()`, codegen/__init__.py:28-51).
460///
461/// Stages:
462/// 0. Movement ops + syntactic sugar (`pm_mops + pm_syntactic_sugar`, bottom-up)
463/// 1. Load collapse (`pm_load_collapse`)
464/// 2. Split ranges + flatten (`pm_split_ranges + pm_flatten_range`)
465/// 3. Symbolic + flatten (`sym + pm_flatten_range`)
466/// 4. Simplify ranges (`pm_simplify_ranges`)
467///
468/// Called by both heuristic and beam search paths.
469#[tracing::instrument(skip_all)]
470pub fn apply_pre_optimization(ast: Arc<morok_ir::UOp>) -> Arc<morok_ir::UOp> {
471 tracing::debug!(ast.initial = ast.tree(), node_count = ast.node_count(), "kernel initial");
472
473 use crate::rangeify::transforms::SplitRangesContext;
474
475 let t_stage = std::time::Instant::now();
476 use crate::rangeify::patterns::{movement_op_patterns, pm_syntactic_sugar};
477 use crate::rewrite::graph_rewrite_bottom_up;
478 static PM_EARLY_MOPS: LazyLock<crate::TypedPatternMatcher> =
479 LazyLock::new(|| movement_op_patterns() + pm_syntactic_sugar());
480 let mut sink = graph_rewrite_bottom_up(&*PM_EARLY_MOPS, ast, &mut ());
481 tracing::debug!(
482 ast.pre = sink.tree(),
483 node_count = sink.node_count(),
484 elapsed_ms = t_stage.elapsed().as_millis() as u64,
485 "pre-opt: movement ops + syntactic sugar complete"
486 );
487
488 let t_stage = std::time::Instant::now();
489 sink = graph_rewrite(pm_load_collapse(), sink, &mut ());
490 tracing::debug!(
491 ast.pre = sink.tree(),
492 node_count = sink.node_count(),
493 elapsed_ms = t_stage.elapsed().as_millis() as u64,
494 "pre-opt: load collapse complete"
495 );
496
497 let t_stage = std::time::Instant::now();
498 let mut split_ctx = SplitRangesContext::default();
499 sink = graph_rewrite(&pm_split_ranges(), sink, &mut split_ctx);
500 sink = graph_rewrite(pm_flatten_range(), sink, &mut ());
501 tracing::debug!(
502 ast.pre = sink.tree(),
503 node_count = sink.node_count(),
504 elapsed_ms = t_stage.elapsed().as_millis() as u64,
505 "pre-opt: split ranges complete"
506 );
507
508 let t_stage = std::time::Instant::now();
509 // Tinygrad: sym + pm_flatten_range (pre-opt uses full sym tier)
510 static PM_SYM_FLATTEN: LazyLock<crate::TypedPatternMatcher> = LazyLock::new(|| sym().clone() + pm_flatten_range());
511 sink = graph_rewrite(&*PM_SYM_FLATTEN, sink, &mut ());
512 tracing::debug!(
513 ast.pre = sink.tree(),
514 node_count = sink.node_count(),
515 elapsed_ms = t_stage.elapsed().as_millis() as u64,
516 "pre-opt: symbolic + flatten complete"
517 );
518
519 let t_stage = std::time::Instant::now();
520 static PM_SIMPLIFY_FLATTEN: LazyLock<crate::TypedPatternMatcher> =
521 LazyLock::new(|| pm_flatten_range() + pm_simplify_ranges());
522 sink = graph_rewrite(&*PM_SIMPLIFY_FLATTEN, sink, &mut ());
523 tracing::debug!(
524 ast.pre = sink.tree(),
525 node_count = sink.node_count(),
526 elapsed_ms = t_stage.elapsed().as_millis() as u64,
527 "pre-opt: simplify ranges complete"
528 );
529
530 sink
531}
532
533/// Apply optimizations with explicit configuration.
534///
535/// Use this when you need explicit control over the optimization settings.
536///
537/// Note: For beam search strategy, this falls back to heuristics because
538/// beam search requires a `compile_and_time` function from the runtime.
539/// Use `optimize_kernel_beam()` for actual beam search optimization.
540pub fn optimize_kernel_with_config(
541 ast: Arc<morok_ir::UOp>,
542 renderer: &Renderer,
543 config: &OptimizerConfig,
544) -> Arc<morok_ir::UOp> {
545 // Pre-optimization: per-kernel stages (Tinygrad: full_rewrite_to_sink)
546 let pre_optimized = apply_pre_optimization(ast);
547
548 let optimized = match config.strategy {
549 OptStrategy::None => pre_optimized, // No heuristic optimization, but post-optimization still needed
550 OptStrategy::Heuristic => optimize_heuristic(pre_optimized, renderer, &config.heuristics),
551 OptStrategy::Beam { .. } => {
552 // Beam search requires a compile_and_time function.
553 // Use optimize_kernel_beam() for actual beam search.
554 // Fall back to heuristics for the simple API.
555 optimize_heuristic(pre_optimized, renderer, &config.heuristics)
556 }
557 };
558
559 // apply_post_optimization contains correctness transforms (pm_add_loads wraps INDEX
560 // with LOAD for arithmetic ops) and must run even when optimizations are disabled.
561 // Pass the renderer to enable GPU dimension injection for GPU backends.
562
563 apply_post_optimization_with_renderer(optimized, Some(renderer))
564}
565
566/// Apply optimizations with explicit strategy selection (legacy API).
567///
568/// Prefer `optimize_kernel_with_config` for new code.
569pub fn optimize_kernel_with_strategy(
570 ast: Arc<morok_ir::UOp>,
571 renderer: &Renderer,
572 strategy: OptStrategy,
573) -> Arc<morok_ir::UOp> {
574 let config = OptimizerConfig { strategy, ..Default::default() };
575 optimize_kernel_with_config(ast, renderer, &config)
576}
577
578/// Apply beam search optimization with custom timing function.
579///
580/// This is the primary entry point for beam search auto-tuning. It requires
581/// a `compile_and_time` function that compiles a scheduler state and returns
582/// its execution timing.
583///
584/// # Arguments
585///
586/// * `ast` - The kernel AST to optimize
587/// * `renderer` - Backend capabilities descriptor
588/// * `config` - Beam search configuration
589/// * `compile_and_time` - Function to compile and time a scheduler
590///
591/// # Returns
592///
593/// Result containing `BeamResult` with optimized scheduler and metrics.
594///
595/// # Example
596///
597/// ```ignore
598/// use morok_schedule::optimizer::{optimize_kernel_beam, BeamConfig, Renderer};
599/// use morok_runtime::{BenchmarkConfig, benchmark_kernel};
600///
601/// let config = BeamConfig::from_env();
602/// let renderer = Renderer::cpu();
603///
604/// let compile_and_time = |scheduler: &Scheduler| -> Option<Duration> {
605/// let ast = scheduler.get_optimized_ast(None);
606/// let kernel = compile_kernel(&ast)?;
607/// let result = benchmark_kernel(&kernel, &buffers, &vars, &bench_config).ok()?;
608/// Some(result.min)
609/// };
610///
611/// let result = optimize_kernel_beam(ast, &renderer, &config, compile_and_time)?;
612/// let optimized_ast = result.scheduler.get_optimized_ast(None);
613/// ```
614pub fn optimize_kernel_beam<F>(
615 ast: Arc<morok_ir::UOp>,
616 renderer: &Renderer,
617 config: &BeamConfig,
618 compile_and_time: F,
619) -> Result<BeamResult, error::OptError>
620where
621 F: Fn(&Scheduler) -> Option<std::time::Duration> + Sync,
622{
623 // Step 0: Per-kernel pre-optimization (Tinygrad: full_rewrite_to_sink)
624 let pre_optimized = apply_pre_optimization(ast);
625
626 // Step 1: Create scheduler (AST already simplified by apply_pre_optimization Stage 3)
627 let mut scheduler = Scheduler::new(pre_optimized, renderer.clone());
628
629 // Step 2: Convert loops to global (for GPU parallelization)
630 let _ = scheduler.convert_loop_to_global();
631
632 // Step 4: Run beam search (with caching)
633 beam::beam_search_cached(scheduler, config, compile_and_time)
634}
635
636/// Create a scheduler ready for optimization without applying any opts.
637///
638/// This is useful when you want to manually control the optimization process
639/// or use beam search with custom logic.
640///
641/// # Arguments
642///
643/// * `ast` - The kernel AST
644/// * `renderer` - Backend capabilities descriptor
645///
646/// # Returns
647///
648/// A `Scheduler` with loops converted to globals (if applicable).
649pub fn prepare_scheduler(ast: Arc<morok_ir::UOp>, renderer: &Renderer) -> Scheduler {
650 let pre_optimized = apply_pre_optimization(ast);
651 let mut scheduler = Scheduler::new(pre_optimized, renderer.clone());
652 let _ = scheduler.convert_loop_to_global(); // GPU: LOOP→GLOBAL
653 // Note: Don't apply threading here - let beam search explore THREAD actions naturally.
654 // Heuristics apply threading via hand_coded_optimizations() with config.thread_count.
655 scheduler
656}
657
658/// Apply heuristic-based optimizations.
659fn optimize_heuristic(ast: Arc<morok_ir::UOp>, renderer: &Renderer, config: &HeuristicsConfig) -> Arc<morok_ir::UOp> {
660 // Step 1: Create scheduler (AST already simplified by apply_pre_optimization Stage 3)
661 let mut scheduler = Scheduler::new(ast, renderer.clone());
662
663 // Step 3: Convert axes for parallelization/vectorization
664 let _ = scheduler.convert_loop_to_global(); // GPU: LOOP→GLOBAL
665 let _ = scheduler.convert_outer_to_loop(); // CPU: OUTER→LOOP (enables UPCAST)
666
667 // Step 4: Apply hand-coded heuristics with config
668 heuristics::hand_coded_optimizations(&mut scheduler, config);
669
670 // Step 5: Extract optimized AST
671 scheduler.get_optimized_ast(None)
672}