morok-schedule 0.1.0-alpha.2

Optimization passes and pattern engine for the Morok ML compiler
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
//! Kernel optimization layer for morok-schedule.
//!
//! This module implements hardware-aware kernel optimization based on Tinygrad's approach.
//! It provides a `Scheduler` that applies optimization primitives (OptOps) to transform
//! kernel execution for better performance on specific backends.
//!
//! # Architecture
//!
//! The optimization process follows this flow:
//!
//! 1. **Initialization**: Create `Scheduler` from UOp AST + `Renderer` (backend capabilities)
//! 2. **Initial Transform**: Convert eligible LOOP axes to GLOBAL (parallelization)
//! 3. **Optimization**: Apply `Opt` operations via `apply_opt()`
//!    - UPCAST: Vectorization (SIMD)
//!    - LOCAL: GPU workgroup dimensions (shared memory)
//!    - UNROLL: Loop unrolling for reductions
//!    - GROUP: Two-stage reductions with synchronization
//!    - TC: Tensor core acceleration
//!    - PADTO, SWAP, THREAD, NOLOCALS: Layout and configuration
//! 4. **Finalization**: Extract optimized AST with `get_optimized_ast()`
//!
//! # Optimization Strategies
//!
//! - **Hand-coded heuristics** (`heuristics` module): Fast, reasonable performance
//! - **Beam search** (`beam` module, optional): Slower, but selects optimizations by compiling and timing candidate kernels
//!
//! # Example
//!
//! ```ignore
//! use morok_schedule::optimizer::{Scheduler, Renderer, Opt, OptOps};
//!
//! // Create scheduler with CUDA backend
//! let renderer = Renderer::cuda();
//! let mut scheduler = Scheduler::new(kernel_ast, renderer);
//!
//! // Apply optimizations
//! scheduler.convert_loop_to_global();
//! scheduler.apply_opt(Opt::upcast(0, 4), true)?; // Vectorize axis 0 by 4
//! scheduler.apply_opt(Opt::local(1, 16), true)?; // Local memory for axis 1
//!
//! // Get optimized kernel
//! let optimized_ast = scheduler.get_optimized_ast(None);
//! ```

pub mod beam;
pub mod config;
pub mod error;
pub mod heuristics;
pub mod kernel_info;
pub mod opts;
pub mod renderer;
pub mod scheduler;
pub mod tc;
pub mod types;

// Re-exports
pub use beam::{BeamResult, beam_search, beam_search_cached, beam_search_with_timeout, clear_cache, replay_opts};
pub use config::{BeamConfig, HeuristicsConfig, OptStrategy, OptimizerConfig, TcOpt as TcOptLevel, TcSelect, TcUsage};
pub use error::OptError;
pub use heuristics::hand_coded_optimizations;
pub use kernel_info::KernelInfo;
pub use opts::apply_opt;
pub use renderer::{Renderer, TcOpt, TensorCore};
pub use scheduler::Scheduler;
#[cfg(test)]
pub use scheduler::clear_kernel_name_counts;
pub use types::{AxisType, Opt, OptArg, OptOps};

use crate::devectorize::{
    Fp8DecompCtx, bool_storage_patterns, pm_float_decomp, pm_float_decomp_store, pm_reduce, pm_render,
    pm_wmma_accumulate,
};
use crate::gpudims::pm_add_gpudims;
// pm_linearize_multi_index removed: Tinygrad keeps multi-index INDEX through the pipeline.
// Codegen backends compute flat addresses at render time.
use crate::rangeify::patterns::{
    pm_add_loads, pm_comparison_negations, pm_demorgan, pm_div_to_shr, pm_erf_decomposition, pm_fdiv_to_mul,
    pm_fma_decomposition, pm_load_collapse, pm_mod_to_and, pm_mul_to_shl, pm_neg_from_mul, pm_shl_add_to_mulacc,
    pm_threefry_decomp, rangeify_codegen_with_kernel_ctx,
};
use crate::rangeify::pm_add_buffers_local_patterns;
use crate::rangeify::transforms::{pm_flatten_range, pm_simplify_ranges, pm_split_ranges};
use crate::rewrite::graph_rewrite;
use crate::symbolic::patterns::{gep_pushing_patterns, sym, symbolic, symbolic_simple};
use std::sync::{Arc, LazyLock};

/// Optimize a kernel AST using settings read from the environment.
///
/// Main optimization entry point for the tensor pipeline. The configuration comes from
/// `OptimizerConfig::from_env()` and is forwarded to [`optimize_kernel_with_config`], which runs:
///
/// 1. **Pre-optimization** - per-kernel rewrites ([`apply_pre_optimization`])
/// 2. **Strategy** - hand-coded heuristics (LOOP→GLOBAL conversion, vectorization,
///    unrolling, tiling), or nothing when the strategy is `OptStrategy::None`
/// 3. **Post-optimization** - lowering passes ([`apply_post_optimization_with_renderer`]),
///    which always run
///
/// # Arguments
///
/// * `ast` - The kernel AST (inner AST from KERNEL op)
/// * `renderer` - Backend capabilities descriptor
///
/// # Returns
///
/// Optimized AST with transformations applied.
///
/// # Environment Variables
///
/// * `MOROK_NOOPT=1` - Disable all optimizations (for debugging)
/// * `MOROK_BEAM=N` - Use beam search with width N (future). When the configured strategy is
///   `OptStrategy::Beam`, this entry point falls back to heuristics because it cannot time
///   kernels; use [`optimize_kernel_beam`] for real beam search.
pub fn optimize_kernel(ast: Arc<morok_ir::UOp>, renderer: &Renderer) -> Arc<morok_ir::UOp> {
    let env_config = OptimizerConfig::from_env();
    optimize_kernel_with_config(ast, renderer, &env_config)
}

/// Apply post-optimization passes to kernel AST.
///
/// These passes run AFTER heuristic/beam optimization and BEFORE codegen:
/// - pm_add_loads: Extract LOAD ops from INDEX
/// - pre_expand: Convert Range(Unroll/Upcast) → UNROLL, expand operations
/// - pm_add_gpudims (GPU only): Convert GLOBAL/LOCAL RANGE to SPECIAL thread indices
/// - devectorize: Combined pass (sym + devec + load_store_folding + correct_load_store + indexing)
/// - bool_storage_patterns: Convert bool LOAD/STORE to uint8
///
/// NOTE: We do NOT apply FMA decomposition (a*b+c → MulAcc). Following Tinygrad's
/// approach, we let LLVM's optimizer fuse MUL+ADD into FMA when beneficial.
///
/// # Arguments
///
/// * `ast` - The kernel AST to optimize
///
/// Called by both heuristic and beam search paths for consistent behavior.
/// For GPU pipelines, use `apply_post_optimization_with_renderer` to enable GPU dimension injection.
#[tracing::instrument(skip_all)]
pub fn apply_post_optimization(ast: Arc<morok_ir::UOp>) -> Arc<morok_ir::UOp> {
    apply_post_optimization_with_renderer(ast, None)
}

/// Apply post-optimization passes with renderer context.
///
/// Lowers an optimized kernel AST into the form expected by codegen. Stages, in order:
///
/// 1. Post-opt symbolic + `pm_move_where_on_load` (must precede expansion)
/// 2. `pre_expand`: expand UNROLL ops into vectorized operations
/// 3. Local buffers: `pm_add_buffers_local_patterns` + `rangeify_codegen_with_kernel_ctx`
/// 4. `pm_reduce` + `pm_wmma_accumulate` + GEP pushing
/// 5. `pm_add_gpudims` - only when `renderer` is `Some` and `has_local` is set; converts
///    GLOBAL/LOCAL RANGE operations to SPECIAL thread indices
/// 6. `pm_add_loads`
/// 7. `devectorize`
/// 8. Index dtype lowering + load/store indexing + GEP pushing
/// 9. Full `symbolic`
/// 10. `symbolic_simple` + late decompositions + `pm_render`
/// 11. `merge_sibling_ends`
/// 12. FP8 decomposition (FP8E5M2 and FP8E4M3 promoted to Float16)
/// 13. `bool_storage_patterns`
///
/// The root's raw metadata (e.g. `KernelInfo`) is saved before the first rewrite and
/// re-attached to the result, because the graph rewrites drop it.
///
/// In debug builds only, the graph is scanned after several stages for UNROLL(GROUP) nodes
/// (and, before devectorize, for any UNROLL/CONTRACT). Any hits are reported via
/// `tracing::error!`.
///
/// # Arguments
///
/// * `ast` - The kernel AST to optimize
/// * `renderer` - Optional renderer for GPU dimension injection
#[tracing::instrument(skip_all)]
pub fn apply_post_optimization_with_renderer(
    ast: Arc<morok_ir::UOp>,
    renderer: Option<&Renderer>,
) -> Arc<morok_ir::UOp> {
    // Save metadata before graph_rewrite destroys it (e.g., KernelInfo with kernel name)
    let saved_metadata = ast.metadata_raw();

    tracing::debug!(ast.initial = ast.tree(), node_count = ast.node_count(), "kernel initial");

    // Tinygrad keeps multi-index INDEX through the pipeline — no linearization here.
    // Codegen backends compute flat addresses at render time via render_linearize_multi_index.

    // Diagnostic helper: log an error for every UNROLL whose source is a GROUP.
    // Each call walks the whole graph (toposort), so every call site below is gated on
    // `debug_assertions` to keep release builds free of this cost.
    let check_unroll_group = |label: &str, root: &Arc<morok_ir::UOp>| {
        for node in root.toposort() {
            if let morok_ir::Op::Unroll { src, unroll_axes, .. } = node.op()
                && matches!(src.op(), morok_ir::Op::Group { .. })
            {
                tracing::error!(id = node.id, axes = ?unroll_axes, stage = label, "UNROLL(GROUP) found!");
            }
        }
    };

    // =========================================================================
    // Stage 8: Post-opt symbolic + WHERE movement (Tinygrad: sym + pm_move_where_on_load)
    // This MUST run BEFORE expander to optimize conditionals before expansion.
    // =========================================================================
    let t_stage = std::time::Instant::now();
    // Tinygrad: sym + pm_move_where_on_load (pm_move_where_on_load only at this stage, not global)
    static POST_OPT_SYM: LazyLock<crate::TypedPatternMatcher> =
        LazyLock::new(|| sym().clone() + crate::symbolic::patterns::pm_move_where_on_load());
    let with_symbolic = graph_rewrite(&*POST_OPT_SYM, ast, &mut ());
    tracing::debug!(
        ast.optimized = with_symbolic.tree(),
        node_count = with_symbolic.node_count(),
        elapsed_ms = t_stage.elapsed().as_millis() as u64,
        "Stage 8: after post-opt symbolic"
    );

    // =========================================================================
    // Stage 9: Expander (Tinygrad: sym + pm_pre_expander + pm_group_for_reduce + expander)
    // =========================================================================
    // UNROLL expansion: Expand UNROLL ops to vectorized operations (Tinygrad expander.py)
    // CRITICAL: Must run BEFORE pm_reduce so that REDUCE sees its actual vectorized dtype.
    // In Tinygrad, expander runs first, then pm_reduce sees the expanded REDUCE with vec2 dtype.
    // This allows reduce_to_acc to create accumulators with the correct vector dtype.
    let t_stage = std::time::Instant::now();
    let expanded = crate::expand::pre_expand(&with_symbolic);
    tracing::debug!(
        ast.optimized = expanded.tree(),
        node_count = expanded.node_count(),
        elapsed_ms = t_stage.elapsed().as_millis() as u64,
        "Stage 9: after pre_expand"
    );

    // =========================================================================
    // Stage 10: Add local buffers (Tinygrad: pm_add_buffers_local + rangeify_codegen)
    // =========================================================================
    // Converts BUFFERIZE(Local) → DEFINE_LOCAL + STORE + LOAD for GROUP_REDUCE.
    // Also strips leftover CONTIGUOUS and NOOP nodes.
    // Must run AFTER expander (which creates BUFFERIZE_LOCAL) and BEFORE pm_reduce.
    //
    // CRITICAL: Combine pm_add_buffers_local + rangeify_codegen in a SINGLE pass
    // (like Tinygrad) to ensure CONTIGUOUS is stripped BEFORE bufferize_to_store
    // sees it. Otherwise CONTIGUOUS(BUFFER) becomes the STORE value directly,
    // which fails codegen because STORE expects a value, not a buffer pointer.
    let t_stage = std::time::Instant::now();
    let with_local_buffers = {
        let mut buf_ctx = crate::rangeify::KernelContext::new();
        static PM_LOCAL_BUF: LazyLock<crate::TypedPatternMatcher<crate::rangeify::KernelContext>> =
            LazyLock::new(|| pm_add_buffers_local_patterns() + rangeify_codegen_with_kernel_ctx());
        graph_rewrite(&*PM_LOCAL_BUF, expanded, &mut buf_ctx)
    };
    tracing::debug!(
        ast.optimized = with_local_buffers.tree(),
        node_count = with_local_buffers.node_count(),
        elapsed_ms = t_stage.elapsed().as_millis() as u64,
        "Stage 10: after add local buffers"
    );
    if cfg!(debug_assertions) {
        check_unroll_group("after_add_local_buffers", &with_local_buffers);
    }

    let t_stage = std::time::Instant::now();
    static PM_REDUCE_COMBINED: LazyLock<crate::TypedPatternMatcher<crate::devectorize::ReduceContext>> =
        LazyLock::new(|| pm_reduce() + pm_wmma_accumulate().with_context() + gep_pushing_patterns().with_context());
    let mut reduce_ctx = crate::devectorize::ReduceContext::default();
    let reduced = graph_rewrite(&*PM_REDUCE_COMBINED, with_local_buffers, &mut reduce_ctx);
    tracing::debug!(
        ast.optimized = reduced.tree(),
        node_count = reduced.node_count(),
        elapsed_ms = t_stage.elapsed().as_millis() as u64,
        "after pm_reduce"
    );
    if cfg!(debug_assertions) {
        check_unroll_group("after_pm_reduce", &reduced);
    }

    // GPU dimension injection only for renderers that expose local (workgroup) dimensions.
    let t_stage = std::time::Instant::now();
    let with_gpudims = if let Some(ren) = renderer {
        if ren.has_local { graph_rewrite(&pm_add_gpudims(), reduced, &mut ren.clone()) } else { reduced }
    } else {
        reduced
    };
    tracing::debug!(
        ast.optimized = with_gpudims.tree(),
        node_count = with_gpudims.node_count(),
        elapsed_ms = t_stage.elapsed().as_millis() as u64,
        "after pm_add_gpudims"
    );
    if cfg!(debug_assertions) {
        check_unroll_group("after_pm_add_gpudims", &with_gpudims);
    }

    let t_stage = std::time::Instant::now();
    let with_loads = graph_rewrite(pm_add_loads(), with_gpudims, &mut ());
    tracing::debug!(
        ast.optimized = with_loads.tree(),
        node_count = with_loads.node_count(),
        elapsed_ms = t_stage.elapsed().as_millis() as u64,
        "after pm_add_loads"
    );
    if cfg!(debug_assertions) {
        check_unroll_group("after_pm_add_loads", &with_loads);
        // Also check for any UNROLL or CONTRACT
        for node in with_loads.toposort() {
            if let morok_ir::Op::Unroll { src, unroll_axes, .. } = node.op() {
                tracing::error!(
                    id = node.id,
                    src_op = src.op().as_ref(),
                    axes = ?unroll_axes,
                    "BEFORE devectorize: found UNROLL!"
                );
            }
            if let morok_ir::Op::Contract { src, upcast_ranges, .. } = node.op() {
                tracing::error!(
                    id = node.id,
                    src_op = src.op().as_ref(),
                    axes = ?upcast_ranges,
                    "BEFORE devectorize: found CONTRACT!"
                );
            }
        }
    }

    // ALU devectorization happens inside devectorize() Phase 1, alongside expand_index
    // and full symbolic (including gep_pushing). This matches Tinygrad's structure where
    // no_vectorized_alu runs in the same pass as load_store_folding (step 14).
    // Previously, an isolated pass here combined no_vectorized_alu + gep_pushing without
    // load/store folding, causing graph explosion on wide VECTORIZE nodes (VECTORIZE(135)).
    // Tinygrad Stage 14: devectorize — single combined pass handles ALL devectorization
    // including bool ALU (via no_vectorized_alu). No separate pm_bool_devectorize or
    // pm_reduce_devectorize passes — matching Tinygrad's pipeline exactly.
    let t_stage = std::time::Instant::now();
    let devectorized = crate::devectorize::devectorize(&with_loads);
    tracing::debug!(
        ast.optimized = devectorized.tree(),
        node_count = devectorized.node_count(),
        elapsed_ms = t_stage.elapsed().as_millis() as u64,
        "after devectorize"
    );
    if cfg!(debug_assertions) {
        check_unroll_group("after_devectorize", &devectorized);
    }

    // Tinygrad Stage 15: pm_lower_index_dtype + load_store_indexing + gep_pushing
    let t_stage = std::time::Instant::now();
    static PM_LOWER_COMBINED: LazyLock<crate::TypedPatternMatcher> = LazyLock::new(|| {
        crate::symbolic::pm_lower_index_dtype()
            + crate::devectorize::load_store_indexing_patterns()
            + gep_pushing_patterns()
    });
    let with_lowered_idx = graph_rewrite(&*PM_LOWER_COMBINED, devectorized, &mut ());
    tracing::debug!(
        ast.optimized = with_lowered_idx.tree(),
        node_count = with_lowered_idx.node_count(),
        elapsed_ms = t_stage.elapsed().as_millis() as u64,
        "after pm_lower_index_dtype"
    );
    if cfg!(debug_assertions) {
        check_unroll_group("after_pm_lower_index_dtype", &with_lowered_idx);
    }

    // Tinygrad: symbolic (step 16) — full symbolic (includes gep_pushing, div_and_mod, etc.)
    let t_stage = std::time::Instant::now();
    static POST_INDEX_SYM: LazyLock<crate::TypedPatternMatcher> = LazyLock::new(|| symbolic().clone());
    let with_lowered_idx = graph_rewrite(&*POST_INDEX_SYM, with_lowered_idx, &mut ());
    tracing::debug!(
        ast.optimized = with_lowered_idx.tree(),
        node_count = with_lowered_idx.node_count(),
        elapsed_ms = t_stage.elapsed().as_millis() as u64,
        "after post-index symbolic"
    );

    // =========================================================================
    // Stage 18-19: Decompositions + Render (Tinygrad: pm_decomp + pm_render in one pass)
    // =========================================================================
    let t_stage = std::time::Instant::now();
    static PM_FINAL: LazyLock<crate::TypedPatternMatcher> =
        LazyLock::new(|| symbolic_simple() + get_late_rewrite_patterns() + pm_render());
    let rendered = graph_rewrite(&*PM_FINAL, with_lowered_idx, &mut ());
    tracing::debug!(
        ast.optimized = rendered.tree(),
        node_count = rendered.node_count(),
        elapsed_ms = t_stage.elapsed().as_millis() as u64,
        "Stage 18-19: after pm_decomp + pm_render"
    );

    // Merge sibling ENDs that share the same reduce ranges.
    // pm_decomp+pm_render can create new sibling ENDs (e.g. by rewriting computations
    // inside an END differently per vector lane). merge_reduce_ends ran earlier in
    // pm_reduce but only caught ENDs that existed at that point.
    let t_merge = std::time::Instant::now();
    let rendered = crate::devectorize::merge_sibling_ends(&rendered);
    tracing::debug!(
        ast.optimized = rendered.tree(),
        node_count = rendered.node_count(),
        elapsed_ms = t_merge.elapsed().as_millis() as u64,
        "after merge_sibling_ends"
    );

    // FP8 float decomposition: promote FP8 computation to Float16 via bitwise conversion.
    // Uses graph_rewrite_with_bpm: STORE pattern in bpm (sees ORIGINAL children to detect
    // FP8 buffer ptrs), all other patterns in pm (sees OPTIMIZED children).
    // Run once per FP8 type. Tinygrad: codegen/__init__.py:97-99
    let t_stage = std::time::Instant::now();
    let fp8_pm = pm_float_decomp();
    let fp8_bpm = pm_float_decomp_store();
    let mut fp8_decomposed = rendered;
    for (fr, to) in [
        (morok_dtype::ScalarDType::FP8E5M2, morok_dtype::ScalarDType::Float16),
        (morok_dtype::ScalarDType::FP8E4M3, morok_dtype::ScalarDType::Float16),
    ] {
        let mut ctx = Fp8DecompCtx { from: fr, to };
        fp8_decomposed = morok_ir::rewrite::graph_rewrite_with_bpm(&fp8_pm, &fp8_bpm, fp8_decomposed, &mut ctx);
    }
    tracing::debug!(
        ast.optimized = fp8_decomposed.tree(),
        node_count = fp8_decomposed.node_count(),
        elapsed_ms = t_stage.elapsed().as_millis() as u64,
        "after pm_float_decomp"
    );

    let t_stage = std::time::Instant::now();
    let bs = graph_rewrite(bool_storage_patterns(), fp8_decomposed, &mut ());
    tracing::debug!(
        ast.optimized = bs.tree(),
        node_count = bs.node_count(),
        elapsed_ms = t_stage.elapsed().as_millis() as u64,
        "after bool_storage_pattern"
    );

    // Re-attach metadata (e.g., KernelInfo) that was lost during graph rewrites
    match saved_metadata {
        Some(meta) => bs.with_metadata_raw(meta),
        None => bs,
    }
}

/// Late rewrite patterns for algebraic decompositions.
///
/// Based on Tinygrad's `get_late_rewrite_patterns` (decompositions.py:438-480).
/// The combined matcher is built lazily on first use and then shared.
///
/// Pattern sets, in the order they are combined:
/// - `pm_fma_decomposition` (MULACC / FMA handling) and `pm_erf_decomposition`
/// - strength reduction: `pm_mod_to_and` (`x % 2^n → x & (2^n-1)`), `pm_mul_to_shl`
///   (`x * 2^n → x << n`), `pm_div_to_shr`, `pm_fdiv_to_mul`, and `pm_neg_from_mul`
///   (`x * -1 → NEG(x)`)
/// - `pm_demorgan`, `pm_shl_add_to_mulacc`, `pm_threefry_decomp`, `pm_comparison_negations`
/// - `fast_division_patterns` (integer division via magic-number multiplication)
/// - `pm_mod_to_idiv` (`x % d → x - d*(x//d)` for non-power-of-two constant `d`)
fn get_late_rewrite_patterns() -> &'static crate::TypedPatternMatcher {
    // Every current backend (LLVM, CUDA, Metal) supports MAX and SQRT natively, so nothing
    // here depends on backend capabilities yet. A backend without them would require passing
    // a capability set in (like Tinygrad's `ops: tuple[Ops, ...]`) and filtering patterns.
    static LATE_PATTERNS: LazyLock<crate::TypedPatternMatcher> = LazyLock::new(|| {
        // `+` is left-associative, so naming prefixes of the chain keeps the exact same
        // combination order as writing it as a single expression.
        let decompositions = pm_fma_decomposition() + pm_erf_decomposition();
        let strength_reduced = decompositions
            + pm_mod_to_and()
            + pm_mul_to_shl()
            + pm_div_to_shr()
            + pm_fdiv_to_mul()
            + pm_neg_from_mul();
        strength_reduced
            + pm_demorgan()
            + pm_shl_add_to_mulacc()
            + pm_threefry_decomp()
            + pm_comparison_negations()
            + crate::symbolic::fast_division_patterns()
            + pm_mod_to_idiv()
    });
    &*LATE_PATTERNS
}

/// MOD → IDIV decomposition (Tinygrad decompositions.py:457).
///
/// Rewrites `x % d → x - d*(x//d)` when `x` has an integer dtype and `d` is a constant whose
/// integer value is greater than 1 and not a power of two. Power-of-two divisors are excluded
/// here; `pm_mod_to_and` in the same late pattern set targets them.
///
/// In `get_late_rewrite_patterns` this set is combined AFTER `fast_division_patterns`, so the
/// IDIV it creates gets decomposed to magic-number multiplication. Without this, standalone
/// MOD nodes for non-power-of-2 divisors survive to codegen unlowered.
///
/// NOTE(review): the identity matches `%` only if IDIV and MOD share the same rounding
/// convention (e.g. both truncating). Confirm this for negative `x`.
fn pm_mod_to_idiv() -> &'static crate::TypedPatternMatcher {
    crate::cached_patterns! {
        // Match MOD whose divisor is a constant; `d_val` is the constant's value.
        Mod(x, d @const(d_val))
            // Integer dtype only, divisor > 1 and not a power of two.
            if x.dtype().is_int()
            && matches!(d_val.try_int(), Some(v) if v > 1 && !((v as u64).is_power_of_two()))
            => {
                // x % d → x - d * (x // d)
                let div = x.idiv(d);
                // If either arithmetic builder fails, the arm yields `None`.
                // Presumably this leaves the MOD unrewritten — TODO confirm with cached_patterns!.
                let mul = d.try_mul(&div).ok()?;
                x.try_sub(&mul).ok()
            },
    }
}

/// Apply per-kernel pre-optimization passes.
///
/// These stages run BEFORE heuristic/beam optimization, once per kernel
/// (Tinygrad: inside `full_rewrite_to_sink()`, codegen/__init__.py:28-51).
///
/// Stages:
/// 0. Movement ops + syntactic sugar (`movement_op_patterns + pm_syntactic_sugar`, bottom-up)
/// 1. Load collapse (`pm_load_collapse`)
/// 2. Split ranges, then flatten (`pm_split_ranges`, then `pm_flatten_range`)
/// 3. Symbolic + flatten (`sym + pm_flatten_range`)
/// 4. Flatten + simplify ranges (`pm_flatten_range + pm_simplify_ranges`)
///
/// After each stage, the resulting tree, its node count and the stage's elapsed time are
/// logged at `debug` level.
///
/// Called by both heuristic and beam search paths.
#[tracing::instrument(skip_all)]
pub fn apply_pre_optimization(ast: Arc<morok_ir::UOp>) -> Arc<morok_ir::UOp> {
    use crate::rangeify::patterns::{movement_op_patterns, pm_syntactic_sugar};
    use crate::rangeify::transforms::SplitRangesContext;
    use crate::rewrite::graph_rewrite_bottom_up;

    // Combined matchers, built once on first use and shared by later calls.
    static EARLY_MOVEMENT: LazyLock<crate::TypedPatternMatcher> =
        LazyLock::new(|| movement_op_patterns() + pm_syntactic_sugar());
    // Tinygrad: sym + pm_flatten_range (pre-opt uses the full sym tier).
    static SYMBOLIC_FLATTEN: LazyLock<crate::TypedPatternMatcher> =
        LazyLock::new(|| sym().clone() + pm_flatten_range());
    static FLATTEN_SIMPLIFY: LazyLock<crate::TypedPatternMatcher> =
        LazyLock::new(|| pm_flatten_range() + pm_simplify_ranges());

    tracing::debug!(ast.initial = ast.tree(), node_count = ast.node_count(), "kernel initial");

    // Stage 0: movement ops + syntactic sugar, rewritten bottom-up.
    let started = std::time::Instant::now();
    let graph = graph_rewrite_bottom_up(&*EARLY_MOVEMENT, ast, &mut ());
    tracing::debug!(
        ast.pre = graph.tree(),
        node_count = graph.node_count(),
        elapsed_ms = started.elapsed().as_millis() as u64,
        "pre-opt: movement ops + syntactic sugar complete"
    );

    // Stage 1: load collapse.
    let started = std::time::Instant::now();
    let graph = graph_rewrite(pm_load_collapse(), graph, &mut ());
    tracing::debug!(
        ast.pre = graph.tree(),
        node_count = graph.node_count(),
        elapsed_ms = started.elapsed().as_millis() as u64,
        "pre-opt: load collapse complete"
    );

    // Stage 2: split ranges (stateful), then flatten.
    let started = std::time::Instant::now();
    let mut split_ctx = SplitRangesContext::default();
    let graph = graph_rewrite(&pm_split_ranges(), graph, &mut split_ctx);
    let graph = graph_rewrite(pm_flatten_range(), graph, &mut ());
    tracing::debug!(
        ast.pre = graph.tree(),
        node_count = graph.node_count(),
        elapsed_ms = started.elapsed().as_millis() as u64,
        "pre-opt: split ranges complete"
    );

    // Stage 3: symbolic simplification combined with range flattening.
    let started = std::time::Instant::now();
    let graph = graph_rewrite(&*SYMBOLIC_FLATTEN, graph, &mut ());
    tracing::debug!(
        ast.pre = graph.tree(),
        node_count = graph.node_count(),
        elapsed_ms = started.elapsed().as_millis() as u64,
        "pre-opt: symbolic + flatten complete"
    );

    // Stage 4: flatten + simplify ranges.
    let started = std::time::Instant::now();
    let graph = graph_rewrite(&*FLATTEN_SIMPLIFY, graph, &mut ());
    tracing::debug!(
        ast.pre = graph.tree(),
        node_count = graph.node_count(),
        elapsed_ms = started.elapsed().as_millis() as u64,
        "pre-opt: simplify ranges complete"
    );

    graph
}

/// Apply optimizations with explicit configuration.
///
/// Use this when you need explicit control over the optimization settings. The kernel goes
/// through pre-optimization, then the strategy chosen by `config.strategy`, then the
/// post-optimization passes. `renderer` is passed to post-optimization so that GPU dimension
/// injection happens for GPU backends.
///
/// Strategy handling:
/// * `OptStrategy::None` - no heuristic optimization. Post-optimization still runs because it
///   contains correctness transforms (e.g. `pm_add_loads` wrapping INDEX with LOAD for
///   arithmetic ops).
/// * `OptStrategy::Heuristic` - hand-coded heuristics configured by `config.heuristics`.
/// * `OptStrategy::Beam { .. }` - falls back to heuristics, because beam search requires a
///   `compile_and_time` function from the runtime. Use `optimize_kernel_beam()` for actual
///   beam search optimization.
pub fn optimize_kernel_with_config(
    ast: Arc<morok_ir::UOp>,
    renderer: &Renderer,
    config: &OptimizerConfig,
) -> Arc<morok_ir::UOp> {
    // Per-kernel stages (Tinygrad: full_rewrite_to_sink).
    let kernel = apply_pre_optimization(ast);

    let kernel = match config.strategy {
        OptStrategy::None => kernel,
        // Beam cannot time kernels from here, so it shares the heuristic path.
        OptStrategy::Heuristic | OptStrategy::Beam { .. } => {
            optimize_heuristic(kernel, renderer, &config.heuristics)
        }
    };

    apply_post_optimization_with_renderer(kernel, Some(renderer))
}

/// Apply optimizations with explicit strategy selection (legacy API).
///
/// Prefer `optimize_kernel_with_config` for new code.
pub fn optimize_kernel_with_strategy(
    ast: Arc<morok_ir::UOp>,
    renderer: &Renderer,
    strategy: OptStrategy,
) -> Arc<morok_ir::UOp> {
    let config = OptimizerConfig { strategy, ..Default::default() };
    optimize_kernel_with_config(ast, renderer, &config)
}

/// Apply beam search optimization with custom timing function.
///
/// This is the primary entry point for beam search auto-tuning. It requires
/// a `compile_and_time` function that compiles a scheduler state and returns
/// its execution timing.
///
/// # Arguments
///
/// * `ast` - The kernel AST to optimize
/// * `renderer` - Backend capabilities descriptor
/// * `config` - Beam search configuration
/// * `compile_and_time` - Function to compile and time a scheduler
///
/// # Returns
///
/// Result containing `BeamResult` with optimized scheduler and metrics.
///
/// # Example
///
/// ```ignore
/// use morok_schedule::optimizer::{optimize_kernel_beam, BeamConfig, Renderer};
/// use morok_runtime::{BenchmarkConfig, benchmark_kernel};
///
/// let config = BeamConfig::from_env();
/// let renderer = Renderer::cpu();
///
/// let compile_and_time = |scheduler: &Scheduler| -> Option<Duration> {
///     let ast = scheduler.get_optimized_ast(None);
///     let kernel = compile_kernel(&ast)?;
///     let result = benchmark_kernel(&kernel, &buffers, &vars, &bench_config).ok()?;
///     Some(result.min)
/// };
///
/// let result = optimize_kernel_beam(ast, &renderer, &config, compile_and_time)?;
/// let optimized_ast = result.scheduler.get_optimized_ast(None);
/// ```
pub fn optimize_kernel_beam<F>(
    ast: Arc<morok_ir::UOp>,
    renderer: &Renderer,
    config: &BeamConfig,
    compile_and_time: F,
) -> Result<BeamResult, error::OptError>
where
    F: Fn(&Scheduler) -> Option<std::time::Duration> + Sync,
{
    // Step 0: Per-kernel pre-optimization (Tinygrad: full_rewrite_to_sink)
    let pre_optimized = apply_pre_optimization(ast);

    // Step 1: Create scheduler (AST already simplified by apply_pre_optimization Stage 3)
    let mut scheduler = Scheduler::new(pre_optimized, renderer.clone());

    // Step 2: Convert loops to global (for GPU parallelization)
    let _ = scheduler.convert_loop_to_global();

    // Step 4: Run beam search (with caching)
    beam::beam_search_cached(scheduler, config, compile_and_time)
}

/// Create a scheduler ready for optimization without applying any opts.
///
/// Runs [`apply_pre_optimization`] on `ast`, wraps the result in a `Scheduler` for `renderer`,
/// and converts eligible LOOP axes to GLOBAL. Useful for driving the optimization process
/// manually or for beam search with custom logic.
///
/// # Arguments
///
/// * `ast` - The kernel AST
/// * `renderer` - Backend capabilities descriptor
///
/// # Returns
///
/// A `Scheduler` with loops converted to globals (if applicable).
pub fn prepare_scheduler(ast: Arc<morok_ir::UOp>, renderer: &Renderer) -> Scheduler {
    let mut prepared = Scheduler::new(apply_pre_optimization(ast), renderer.clone());

    // GPU: LOOP→GLOBAL. The return value is intentionally discarded.
    let _ = prepared.convert_loop_to_global();

    // Threading is deliberately NOT applied here: beam search explores THREAD actions itself,
    // and heuristics apply threading via hand_coded_optimizations() with config.thread_count.
    prepared
}

/// Apply heuristic-based optimizations.
///
/// Expects an AST already processed by `apply_pre_optimization`. Builds a `Scheduler`,
/// converts LOOP axes to GLOBAL (GPU) and OUTER axes to LOOP (CPU), runs the hand-coded
/// heuristics with `config`, and returns the scheduler's optimized AST.
fn optimize_heuristic(
    ast: Arc<morok_ir::UOp>,
    renderer: &Renderer,
    config: &HeuristicsConfig,
) -> Arc<morok_ir::UOp> {
    let mut sched = Scheduler::new(ast, renderer.clone());

    // Axis conversions for parallelization/vectorization; return values are discarded.
    let _ = sched.convert_loop_to_global(); // GPU: LOOP→GLOBAL
    let _ = sched.convert_outer_to_loop(); // CPU: OUTER→LOOP (enables UPCAST)

    heuristics::hand_coded_optimizations(&mut sched, config);

    sched.get_optimized_ast(None)
}