Skip to main content

morok_schedule/rangeify/
transforms.rs

1//! Consolidated transformation functions for rangeify.
2//!
3//! This module contains:
4//! - Main `rangeify()` entry point
5//! - Movement op → BUFFERIZE+INDEX transformation helpers
6//! - BUFFERIZE → STORE conversion
7//! - Reduction simplifications (reduce_unparented, reduce_collapse)
8//! - Range flattening (flatten_range_impl)
9//! - Cycle detection (find_bufs)
10//!
11//! Consolidated from: transform.rs, bufferize_to_store.rs, reduce_simplify.rs,
12//! flatten_range.rs, cycle_detection.rs
13
14use std::collections::{HashMap, HashSet};
15use std::sync::Arc;
16
17use super::context::RangeifyContext;
18use super::indexing::IndexingContext;
19use super::kernel::KernelContext;
20use morok_ir::shape::Shape;
21use morok_ir::{AddrSpace, AxisType, BufferizeOpts, ConstValue, DType, Op, UOp, UOpKey};
22use smallvec::{SmallVec, smallvec};
23
24// ============================================================================
25// ADD_TAGS — Tinygrad rangeify.py:542-555
26// ============================================================================
27
28/// Context for the add_tags pass.
29pub struct AddTagsCtx {
30    /// Sequential list of tagged UOps (index = tag value).
31    pub uop_list: Vec<Arc<UOp>>,
32    /// UOps excluded from tagging (e.g., nodes inside KERNEL/CALL).
33    excluded: HashSet<UOpKey>,
34}
35
36impl Default for AddTagsCtx {
37    fn default() -> Self {
38        Self::new()
39    }
40}
41
42impl AddTagsCtx {
43    pub fn new() -> Self {
44        Self { uop_list: Vec::new(), excluded: HashSet::new() }
45    }
46}
47
48/// Ops that should NOT be tagged (Tinygrad rangeify.py:552-553).
49/// MStack/MSelect are handled separately with conditional logic.
50/// Note: Tinygrad also excludes LUNIQUE — Morok uses counter-based local IDs instead.
51fn should_skip_tag(op: &Op) -> bool {
52    matches!(
53        op,
54        Op::Param { .. }
55            | Op::Const(_)
56            | Op::Device(_)
57            | Op::Unique(_)
58            | Op::DefineVar { .. }
59            | Op::Bind { .. }
60            | Op::End { .. }
61            | Op::Range { .. }
62    ) || op.is_movement()
63}
64
/// Create the add_tags pattern matcher (Tinygrad rangeify.py:550-555).
///
/// Assigns sequential integer tags `[i]` to each taggable UOp. Tags track which
/// original tensor UOps map to which final kernel outputs through the pipeline.
///
/// The tag value is the node's position in `AddTagsCtx::uop_list`: the node is
/// pushed (still untagged) and then rewritten to carry `[uop_list.len() - 1]`.
///
/// A node is left unchanged (the rule returns `None`) when it:
/// - already has a tag, or is recorded in `ctx.excluded`;
/// - is rejected by `should_skip_tag`;
/// - has an `Index` base dtype;
/// - is an MSTACK/MSELECT whose sources are all PARAM.
pub fn add_tags_patterns() -> crate::TypedPatternMatcher<AddTagsCtx> {
    crate::patterns! {
        @context AddTagsCtx;
        // Wildcard: handles all ops, applies tag logic per Tinygrad rangeify.py:542-554
        x => {
            // Already processed, or inside an excluded kernel body.
            if x.tag().is_some() || ctx.excluded.contains(&UOpKey(x.clone())) { return None; }
            // Kernel/Call: exclude entire subgraph from tagging (Tinygrad line 544-546)
            // Only `Op::Kernel` is matched here; every node of its `ast` goes into
            // `ctx.excluded`. The KERNEL node itself is not excluded and falls through
            // to the checks below.
            // NOTE(review): exclusion only helps if the KERNEL is visited before its
            // `ast` nodes — confirm against graph_rewrite_bottom_up's visit order, and
            // confirm whether a separate CALL op also needs this treatment.
            if let Op::Kernel { ast, .. } = x.op() {
                for u in ast.toposort() {
                    ctx.excluded.insert(UOpKey(u));
                }
            }
            if should_skip_tag(x.op()) { return None; }
            // Index-typed scalars are not tagged (Tinygrad line 547)
            if x.dtype().base() == morok_dtype::ScalarDType::Index { return None; }
            // MStack/MSelect: only tag if NOT all sources are PARAM (Tinygrad line 554)
            if matches!(x.op(), Op::MStack { .. } | Op::MSelect { .. })
                && x.op().sources().iter().all(|s| matches!(s.op(), Op::Param { .. }))
            {
                return None;
            }
            // Record the untagged node; its index in uop_list becomes its tag.
            ctx.uop_list.push(x.clone());
            Some(x.with_tag(smallvec![ctx.uop_list.len() - 1]))
        },
    }
}
95
96// ============================================================================
97// PUBLIC API
98// ============================================================================
99
100/// Main rangeify transformation entry point.
101///
102/// Converts movement operations (RESHAPE, PERMUTE, EXPAND, PAD, SHRINK, FLIP)
103/// into BUFFERIZE + INDEX operations with explicit loop ranges.
104pub fn rangeify(
105    sink: Arc<UOp>,
106    pcontig_config: Option<&super::kernel::PcontigConfig>,
107) -> morok_ir::Result<(Arc<UOp>, RangeifyContext)> {
108    let result = rangeify_with_map(sink, pcontig_config)?;
109    Ok((result.sink, result.context))
110}
111
/// Result of rangeify transformation.
///
/// Produced by `rangeify_with_map`; `rangeify` keeps only `sink` and `context`.
pub struct RangeifyResult {
    /// The transformed sink node
    pub sink: Arc<UOp>,
    /// Context with range information.
    ///
    /// NOTE(review): `rangeify_with_map` fills `range_counter` from the indexing
    /// context but currently leaves `range_map` empty.
    pub context: RangeifyContext,
    /// Tagged UOps from add_tags pass (index = tag value).
    /// Used for tag-based becomes_map construction (Tinygrad rangeify.py:614-619).
    pub uop_list: Vec<Arc<UOp>>,
}
122
123/// Main rangeify transformation entry point with becomes_map tracking.
124///
125/// Like `rangeify`, but also returns a `becomes_map` that tracks which
126/// original nodes were transformed. This is essential for global graph
127/// coordination when multiple tensors share subgraphs.
128///
129/// # Pipeline (Tinygrad-aligned)
130///
131/// The pipeline follows Tinygrad's structure from codegen/__init__.py:
132///
133/// **Stage 0**: Range assignment (run_rangeify)
134/// **Stage 1**: pm_mops + pm_syntactic_sugar (BOTTOM_UP) - Early movement ops
135/// **Stage 2**: pm_load_collapse - Collapse load tensor indexing
136/// **Stage 3**: pm_split_ranges + pm_flatten_range - Range splitting
137/// **Stage 4**: sym + pm_flatten_range - Initial symbolic (TOP_DOWN)
138/// **Stage 5**: pm_simplify_ranges - Simplify/merge ranges
139/// **Stage 6**: apply_opts - Post-range optimization (happens in optimizer)
140#[allow(clippy::mutable_key_type)]
141#[tracing::instrument(skip_all)]
142pub fn rangeify_with_map(
143    sink: Arc<UOp>,
144    pcontig_config: Option<&super::kernel::PcontigConfig>,
145) -> morok_ir::Result<RangeifyResult> {
146    // add_tags: assign sequential integer tags to UOps (Tinygrad rangeify.py:575).
147    // MUST run FIRST — tags track tensor identity through the entire pipeline.
148    let t_stage = std::time::Instant::now();
149    let mut tag_ctx = AddTagsCtx::new();
150    let mut sink = crate::rewrite::graph_rewrite_bottom_up(&add_tags_patterns(), sink, &mut tag_ctx);
151    let uop_list = tag_ctx.uop_list;
152    tracing::debug!(
153        tagged_count = uop_list.len(),
154        node_count = sink.node_count(),
155        elapsed_ms = t_stage.elapsed().as_millis() as u64,
156        "add_tags complete"
157    );
158
159    // Combined early pass (Tinygrad: earliest_rewrites + replace_contiguous, ctx={})
160    // MUST run BEFORE range assignment so rangeify sees a cleaned graph.
161    // Tinygrad (rangeify.py:577): bottom_up=True — patterns see ORIGINAL children.
162    let t_stage = std::time::Instant::now();
163    let early_combined = super::patterns::early_rewrites().with_context::<super::patterns::ReplaceContiguousCtx>()
164        + super::patterns::replace_contiguous();
165    let mut contig_ctx = super::patterns::ReplaceContiguousCtx::new();
166    sink = crate::rewrite::graph_rewrite_bottom_up(&early_combined, sink, &mut contig_ctx);
167    tracing::debug!(
168        uop.tree = sink.tree(),
169        node_count = sink.node_count(),
170        elapsed_ms = t_stage.elapsed().as_millis() as u64,
171        "early rewrites + replace contiguous complete"
172    );
173
174    // Split large reductions BEFORE range assignment (Tinygrad: inside earliest_rewrites)
175    let t_stage = std::time::Instant::now();
176    let mut split_config = super::kernel::SplitReduceOpConfig::default();
177    let split_matcher = super::patterns::split_reduceop_patterns();
178    sink = crate::rewrite::graph_rewrite(&split_matcher, sink, &mut split_config);
179    tracing::debug!(
180        uop.tree = sink.tree(),
181        node_count = sink.node_count(),
182        elapsed_ms = t_stage.elapsed().as_millis() as u64,
183        "split reduceops complete"
184    );
185
186    // =========================================================================
187    // Stage 0: Range assignment + apply rangeify patterns
188    // Tinygrad: run_rangeify() includes pm_generate_realize_map, assign loop,
189    // and pm_apply_rangeify (REDUCE_AXIS→REDUCE, PAD→WHERE, BUFFERIZE+INDEX)
190    // =========================================================================
191    let t_stage = std::time::Instant::now();
192    let (rangeified, indexing_ctx) = super::indexing::run_rangeify(sink)?;
193    sink = rangeified;
194    tracing::debug!(
195        uop.tree = sink.tree(),
196        node_count = sink.node_count(),
197        elapsed_ms = t_stage.elapsed().as_millis() as u64,
198        "Stage 0: range assignment + apply rangeify complete"
199    );
200
201    // =========================================================================
202    // Mega-pass: symbolic + reduce_simplify + buffer_folding + buffer_removal
203    // (Tinygrad rangeify.py:582: symbolic + pm_reduce_simplify + pm_const_buffer_folding + pm_remove_bufferize)
204    //
205    // One fixpoint pass combining all simplification + buffer removal.
206    // Uses PcontigConfig as the shared context (buffer_removal needs it;
207    // other patterns are lifted via with_context()).
208    // =========================================================================
209    {
210        use super::kernel::PcontigConfig;
211        let t_stage = std::time::Instant::now();
212        use std::sync::LazyLock;
213        static MEGA_PASS: LazyLock<crate::TypedPatternMatcher<PcontigConfig>> = LazyLock::new(|| {
214            crate::symbolic::symbolic().with_context::<PcontigConfig>()
215                + super::patterns::pm_reduce_simplify().with_context()
216                + super::patterns::buffer_folding().with_context()
217                + super::patterns::dead_axis_removal().with_context()
218                // pm_mops: Tinygrad includes movement_op_patterns in pm_const_buffer_folding (rangeify.py:260)
219                + super::patterns::movement_op_patterns().with_context()
220                + super::patterns::buffer_removal_with_pcontig()
221        });
222        let mega_pass = &*MEGA_PASS;
223        tracing::debug!(
224            total_patterns = mega_pass.len(),
225            wildcard_count = mega_pass.wildcard_count(),
226            indexed_buckets = mega_pass.indexed_count(),
227            "mega-pass pattern stats"
228        );
229        let mut pcontig = pcontig_config.cloned().unwrap_or_default();
230        sink = crate::rewrite::graph_rewrite(mega_pass, sink, &mut pcontig);
231        tracing::debug!(
232            node_count = sink.node_count(),
233            elapsed_ms = t_stage.elapsed().as_millis() as u64,
234            "mega-pass complete"
235        );
236    }
237
238    // Stages 2a-6 (load_collapse, split_ranges, symbolic+flatten, simplify_ranges,
239    // split_store) now run per-kernel in optimizer::apply_pre_optimization().
240
241    // SINK rebuild: filter sources to tagged valid output types (Tinygrad rangeify.py:585-589).
242    // TODO: Full Tinygrad approach scans backward_slice for ALL tagged nodes and rebuilds
243    // SINK with them. This requires Phase 7 (tag-based becomes_map) to be implemented first.
244    // For now, filter existing SINK sources by tag + op type.
245    if let Op::Sink { sources } = sink.op() {
246        let filtered: Vec<Arc<UOp>> = sources
247            .iter()
248            .filter(|s| {
249                let valid_op = matches!(
250                    s.base().op(),
251                    Op::Bufferize { .. } | Op::MStack { .. } | Op::Const(_) | Op::Param { .. } | Op::After { .. }
252                );
253                valid_op
254            })
255            .cloned()
256            .collect();
257        if !filtered.is_empty() && filtered.len() != sources.len() {
258            tracing::debug!(
259                original = sources.len(),
260                filtered = filtered.len(),
261                "SINK cleanup: removed invalid-type sources after mega-pass"
262            );
263            sink = UOp::sink(filtered);
264        }
265    }
266
267    // Buffer limit enforcement
268    if let Some(device) = super::patterns::extract_device_from_graph(&sink)
269        && let Some(limit) = device.max_buffers()
270    {
271        let t_stage = std::time::Instant::now();
272        let limit_matcher = super::patterns::buffer_limit_patterns(limit);
273        sink = crate::rewrite::graph_rewrite(&limit_matcher, sink, &mut ());
274        tracing::debug!(
275            uop.tree = sink.tree(),
276            elapsed_ms = t_stage.elapsed().as_millis() as u64,
277            "Stage 7b: buffer limit enforcement complete"
278        );
279    }
280
281    // =========================================================================
282    // Stage 8: Post-range optimization happens in optimizer module (apply_opts)
283    // =========================================================================
284
285    // Build RangeifyContext for return
286    let rangeify_ctx = RangeifyContext { range_counter: indexing_ctx.range_counter(), range_map: HashMap::new() };
287
288    Ok(RangeifyResult { sink, context: rangeify_ctx, uop_list })
289}
290
/// Pattern matcher for range flattening.
///
/// Based on Tinygrad's pm_flatten_range (simplify.py:14-17).
/// Extracts all RANGE operations from nested END/REDUCE/STORE structures.
///
/// Each rule fires only when the node's `ranges` list is non-empty and
/// delegates the actual rewrite to `flatten_range_impl`.
/// NOTE(review): `cached_patterns!` presumably builds the matcher once and hands
/// out the `'static` reference on every call — confirm in the macro definition.
pub fn pm_flatten_range() -> &'static crate::TypedPatternMatcher {
    crate::cached_patterns! {
        r @ End { computation: _, ranges } if !ranges.is_empty() => |r| flatten_range_impl(r),
        r @ Reduce { src: _, ranges, reduce_op: _ } if !ranges.is_empty() => |r| flatten_range_impl(r),
        r @ Store { index: _, value: _, ranges } if !ranges.is_empty() => |r| flatten_range_impl(r),
    }
}
302
303// ============================================================================
304// RANGE SPLITTING (pm_split_ranges equivalent)
305// ============================================================================
306
/// Context for tracking ranges that should be split via modulo decomposition.
///
/// Based on Tinygrad's pm_split_ranges (simplify.py:60-64).
/// When we see `RANGE % const`, we mark the range for splitting at the SINK.
///
/// Derives `Debug` so the pass state can be inspected in logs and test failures.
#[derive(Debug, Default)]
pub struct SplitRangesContext {
    /// Maps RANGE ids to their modulo constant for decomposition
    pub marked_ranges: HashMap<u64, i64>,
    /// RANGE ids that should NOT be substituted (e.g., used in ImageDType stores)
    protected_ranges: HashSet<u64>,
}
318
/// Pattern matcher for range splitting via modulo arithmetic.
///
/// Based on Tinygrad's pm_split_ranges (simplify.py:60-64).
/// This is a context-collecting pass that:
/// 1. Marks RANGE ops used in `RANGE % const` expressions
/// 2. Protects ranges used in ImageDType stores from substitution
/// 3. At SINK, substitutes marked (non-protected) ranges with divmod decomposition
///
/// A range is only marked when `is_divisible_range_end` holds: the range end
/// and the modulus are integer constants, the modulus is positive, and it
/// divides the end exactly. If a range appears under several moduli, the first
/// one recorded wins.
///
/// Example transformation for `RANGE(12) % 4`:
/// - Original: `r = RANGE(12)`
/// - After: `r_div = RANGE(3) * 4`, `r_mod = RANGE(4)`, substitute `r → r_div + r_mod`
///
/// # ImageDType Protection
///
/// Range splitting must NOT apply to ImageDType stores because image addressing
/// uses special 2D coordinates that don't follow standard linear indexing.
/// Applying range splitting to image stores corrupts the addressing scheme.
/// Protecting a range also removes any mark it already had.
///
/// NOTE(review): the SINK rule only substitutes ranges that were marked earlier
/// in the same rewrite, so this relies on the MOD/STORE nodes being visited
/// before the SINK — confirm against the rewrite driver used with this matcher.
pub fn pm_split_ranges() -> crate::TypedPatternMatcher<SplitRangesContext> {
    crate::patterns! {
        @context SplitRangesContext;

        // Mark RANGE % const: record the modulo constant for this range
        _modop @ Mod(r @ Range { end, axis_id: _, axis_type: _ }, c @ Const(_))
            if is_divisible_range_end(end, c) => |r, c| {
                mark_range_mod(ctx, r, c);
                None // Don't transform yet, just mark
            },

        // Protect ranges used in ImageDType stores from substitution
        _store @ Store { index: idx @ Index { buffer: buf, indices: _, gate: _ }, value: _, ranges: _ }
            if is_image_dtype(buf) => |idx| {
                protect_ranges_for_image(ctx, idx);
                None // Don't transform, just protect
            },

        // At SINK: perform the substitution (do_split_ranges_substitute clears
        // the marks once it has produced a rewritten graph)
        sink @ Sink { sources: _ } if !ctx.marked_ranges.is_empty() => |sink| {
            do_split_ranges_substitute(ctx, sink)
        },
    }
}
360
361/// Check if a buffer has ImageDType.
362fn is_image_dtype(buf: &Arc<UOp>) -> bool {
363    matches!(buf.dtype(), DType::Image { .. })
364}
365
366/// Protect all ranges reachable from an INDEX used in an ImageDType store.
367fn protect_ranges_for_image(ctx: &mut SplitRangesContext, idx: &Arc<UOp>) {
368    for node in idx.toposort() {
369        if matches!(node.op(), Op::Range { .. }) {
370            ctx.protected_ranges.insert(node.id);
371            // Also remove from marked_ranges if already marked
372            ctx.marked_ranges.remove(&node.id);
373        }
374    }
375}
376
377/// Extract i64 from a Const UOp.
378fn const_uop_to_i64(c: &Arc<UOp>) -> Option<i64> {
379    match c.op() {
380        Op::Const(cv) => match cv.0 {
381            ConstValue::Int(v) => Some(v),
382            ConstValue::UInt(v) => Some(v as i64),
383            _ => None,
384        },
385        _ => None,
386    }
387}
388
389/// Check if a RANGE end is divisible by the modulo constant.
390fn is_divisible_range_end(end: &Arc<UOp>, c: &Arc<UOp>) -> bool {
391    let Some(end_val) = const_uop_to_i64(end) else {
392        return false;
393    };
394    let Some(mod_val) = const_uop_to_i64(c) else {
395        return false;
396    };
397    mod_val > 0 && end_val % mod_val == 0
398}
399
400/// Mark a range for modulo decomposition.
401fn mark_range_mod(ctx: &mut SplitRangesContext, r: &Arc<UOp>, c: &Arc<UOp>) {
402    // Don't mark if already marked or protected (e.g., used in ImageDType store)
403    if ctx.marked_ranges.contains_key(&r.id) || ctx.protected_ranges.contains(&r.id) {
404        return;
405    }
406    if let Some(mod_val) = const_uop_to_i64(c) {
407        ctx.marked_ranges.insert(r.id, mod_val);
408    }
409}
410
411/// Perform the substitution at SINK level.
412///
413/// For each marked RANGE with `end` divisible by `mod_val`:
414/// - Create `r_outer = RANGE(end / mod_val)` with same axis type, shifted axis_id
415/// - Create `r_inner = RANGE(mod_val)` with same axis type, shifted axis_id
416/// - Substitute `r → r_outer * mod_val + r_inner`
417fn do_split_ranges_substitute(ctx: &mut SplitRangesContext, sink: &Arc<UOp>) -> Option<Arc<UOp>> {
418    use morok_ir::AxisId;
419    use morok_ir::rewrite::graph_rewrite_bottom_up;
420
421    if ctx.marked_ranges.is_empty() {
422        return None;
423    }
424
425    // Build substitution map
426    let mut subs: HashMap<u64, Arc<UOp>> = HashMap::new();
427
428    // Collect all UOps to find the marked ranges and max axis_id
429    let topo = sink.toposort();
430
431    // Find max axis_id across ALL ranges to avoid collisions when creating split ranges
432    let mut max_axis_id: usize = 0;
433    for uop in &topo {
434        if let Op::Range { axis_id, .. } = uop.op() {
435            max_axis_id = max_axis_id.max(axis_id.value());
436        }
437    }
438    let mut next_id = max_axis_id + 1;
439
440    for uop in &topo {
441        // Skip protected ranges (e.g., used in ImageDType stores)
442        if ctx.protected_ranges.contains(&uop.id) {
443            continue;
444        }
445        if let Some(&mod_val) = ctx.marked_ranges.get(&uop.id)
446            && let Op::Range { end, axis_type, .. } = uop.op()
447        {
448            let Some(end_val) = const_uop_to_i64(end) else {
449                continue;
450            };
451
452            // Create outer range with unique axis_id
453            let outer_end = end_val / mod_val;
454            let outer_range = UOp::range_axis(UOp::index_const(outer_end), AxisId::Renumbered(next_id), *axis_type);
455            next_id += 1;
456
457            // Create inner range with unique axis_id
458            let inner_range = UOp::range_axis(UOp::index_const(mod_val), AxisId::Renumbered(next_id), *axis_type);
459            next_id += 1;
460
461            // Substitution: r → outer * mod_val + inner
462            let mod_const = UOp::index_const(mod_val);
463            let outer_scaled = outer_range.mul(&mod_const);
464            let combined = outer_scaled.add(&inner_range);
465
466            subs.insert(uop.id, combined);
467        }
468    }
469
470    if subs.is_empty() {
471        return None;
472    }
473
474    // Apply substitutions using the substitute pattern
475    let substitute_pm = crate::patterns! {
476        r @ Range { end: _, axis_id: _, axis_type: _ } if subs.contains_key(&r.id) => {
477            subs.get(&r.id).cloned()
478        },
479    };
480
481    let result = graph_rewrite_bottom_up(&substitute_pm, sink.clone(), &mut ());
482
483    // Clear the context after substitution
484    ctx.marked_ranges.clear();
485
486    Some(result)
487}
488
489// ============================================================================
490// TRANSFORM HELPERS (movement ops → BUFFERIZE + INDEX)
491// ============================================================================
492
493/// Transform a UOp's sources by adding BUFFERIZE + INDEX where needed.
494pub fn transform_sources_with_bufferize(x: &Arc<UOp>, ctx: &mut IndexingContext) -> Option<Vec<Arc<UOp>>> {
495    if matches!(x.op(), Op::Bufferize { .. } | Op::Index { .. } | Op::After { .. }) {
496        return None;
497    }
498
499    let sources = x.op().sources();
500    if sources.is_empty() {
501        return None;
502    }
503
504    // Tinygrad: INDEX is only added when `x in ctx.range_map`.
505    // For SINK (not in range_map), realized sources still get BUFFERIZE but no INDEX.
506    let input_ranges = if let Some((ranges, _)) = ctx.get_ranges(x) { ranges.clone() } else { Vec::new() };
507
508    let mut new_sources = Vec::with_capacity(sources.len());
509    let mut any_changed = false;
510
511    for src in sources.iter() {
512        let new_src = transform_single_source(x, src, &input_ranges, ctx);
513        if !Arc::ptr_eq(&new_src, src) {
514            any_changed = true;
515        }
516        new_sources.push(new_src);
517    }
518
519    if any_changed { Some(new_sources) } else { None }
520}
521
522/// Flatten multi-range BUFFERIZE to single-range via RESHAPE to 1D.
523///
524/// Matches Tinygrad's `flatten_bufferize` (rangeify.py:381-389):
525/// 1. Reshapes multi-dim ranges to a single flat index via apply_reshape_ranges
526/// 2. Creates new BUFFERIZE with single computed range
527/// 3. Wraps with RESHAPE back to original shape for downstream movement ops
528/// 4. For symbolic range ends, adds SHRINK to symbolic shape
529///
530/// After this, `bufferize_to_store` only sees single-range BUFFERIZE.
531fn flatten_bufferize(bufferize: &Arc<UOp>) -> Option<Arc<UOp>> {
532    let Op::Bufferize { compute, ranges, opts } = bufferize.op() else { return None };
533    if ranges.len() <= 1 {
534        return None;
535    }
536    // Extract shape from ranges: RANGE(end) → SInt::from(end), CONST(0) → 1
537    let shape: Vec<morok_ir::SInt> = ranges
538        .iter()
539        .map(|r| match r.op() {
540            Op::Range { end, .. } => morok_ir::SInt::from(end.clone()),
541            _ => morok_ir::SInt::from(1usize),
542        })
543        .collect();
544
545    // Flatten: apply_reshape_ranges(in_shape=(prod,), out_shape=shape, rngs=ranges)
546    let flat_shape = vec![morok_ir::sint_prod(&shape)];
547    let ranges_vec: Vec<Arc<UOp>> = ranges.iter().cloned().collect();
548    let flat_indices = super::indexing::apply_reshape_ranges(&flat_shape, &shape, &ranges_vec);
549    assert_eq!(flat_indices.len(), 1, "flatten_bufferize: expected 1 flat index, got {}", flat_indices.len());
550    // New BUFFERIZE with single range
551    let flat_buf = UOp::bufferize(compute.clone(), vec![flat_indices[0].clone()], opts.clone());
552
553    // RESHAPE back to original shape (Tinygrad: ret.forced_reshape(x.shape))
554    let shape_smallvec: Shape = shape.iter().cloned().collect();
555    let reshaped = flat_buf.try_reshape(&shape_smallvec).expect("flatten_bufferize: try_reshape failed");
556
557    // For symbolic range ends, add SHRINK to symbolic shape
558    // Tinygrad: if any(r.op is Ops.RANGE and r.src[0].op is not Ops.CONST for r in rngs)
559    let has_symbolic =
560        ranges.iter().any(|r| matches!(r.op(), Op::Range { end, .. } if !matches!(end.op(), Op::Const(_))));
561
562    if has_symbolic {
563        let sym_ranges: Vec<(morok_ir::SInt, morok_ir::SInt)> = ranges
564            .iter()
565            .map(|r| match r.op() {
566                Op::Range { end, .. } => (morok_ir::SInt::from(0usize), morok_ir::SInt::from(end.clone())),
567                _ => (morok_ir::SInt::from(0usize), morok_ir::SInt::from(1usize)),
568            })
569            .collect();
570        Some(reshaped.try_shrink(&sym_ranges).expect("flatten_bufferize: try_shrink failed for symbolic ranges"))
571    } else {
572        Some(reshaped)
573    }
574}
575
576/// Push movement op through AFTER: `AFTER(MOVEMENT(x), deps) → MOVEMENT(AFTER(x, deps))`.
577///
578/// Matches Tinygrad's pm_mops rule 2 (rangeify.py:28-29):
579///   `UOp(r.op, r.dtype, (a.replace(src=(r.src[0],)+a.src[1:]),)+r.src[1:], r.arg)`
580/// Directly reuses the original movement op's parameters (no roundtrip/validation).
581pub(crate) fn push_movement_through_after(mop: &Arc<UOp>, deps: &SmallVec<[Arc<UOp>; 4]>) -> Option<Arc<UOp>> {
582    let inner_src = &mop.op().sources()[0];
583    let new_after = inner_src.after(deps.clone());
584    // Re-create the movement op with new_after as source, reusing original parameters.
585    // Tinygrad: UOp(r.op, r.dtype, (new_after,)+r.src[1:], r.arg)
586    let new_op = match mop.op() {
587        Op::Reshape { new_shape, .. } => Op::Reshape { src: new_after, new_shape: new_shape.clone() },
588        Op::Permute { axes, .. } => Op::Permute { src: new_after, axes: axes.clone() },
589        Op::Expand { new_shape, .. } => Op::Expand { src: new_after, new_shape: new_shape.clone() },
590        Op::Pad { begin_pads, end_pads, .. } => {
591            Op::Pad { src: new_after, begin_pads: begin_pads.clone(), end_pads: end_pads.clone() }
592        }
593        Op::Shrink { begins, ends, .. } => Op::Shrink { src: new_after, begins: begins.clone(), ends: ends.clone() },
594        Op::Flip { axes, .. } => Op::Flip { src: new_after, axes: axes.clone() },
595        _ => return None,
596    };
597    Some(UOp::new(new_op, mop.dtype()))
598}
599
600/// Transform a single source by adding BUFFERIZE + INDEX if needed.
601///
602/// Non-recursive: only handles immediate buffer-like and realized sources.
603/// Movement ops and compute ops are left for the BPM rewrite engine to
604/// process individually (matching Tinygrad's `create_bufferize_and_index_based_on_ranges`).
605///
606/// INDEX nodes are created with a single linear index (matching Tinygrad),
607/// computed from the buffer's dimensional ranges and the consumer's index
608/// expressions. This eliminates the need for a later linearization pass.
609pub(crate) fn transform_single_source(
610    consumer: &Arc<UOp>,
611    src: &Arc<UOp>,
612    input_ranges: &[Arc<UOp>],
613    ctx: &mut IndexingContext,
614) -> Arc<UOp> {
615    // Case 1: Buffer-like op → add multi-index INDEX
616    // Unlike Case 2 (BUFFERIZE), we can't linearize here because the buffer's
617    // dimensional structure isn't directly available from ctx — the output_ranges
618    // may contain PAD validity expressions, not clean RANGE ops.
619    // Multi-index INDEX is preserved through the pipeline; codegen linearizes at render time.
620    if matches!(
621        src.op(),
622        Op::Buffer { .. }
623            | Op::Param { .. }
624            | Op::BufferView { .. }
625            | Op::MStack { .. }
626            | Op::MSelect { .. }
627            | Op::After { .. }
628    ) {
629        if !input_ranges.is_empty() {
630            return UOp::index()
631                .buffer(Arc::clone(src))
632                .indices(input_ranges.to_vec())
633                .call()
634                .expect("Failed to create INDEX for buffer source");
635        }
636        return Arc::clone(src);
637    }
638
639    // Case 2: source needs realization → wrap in BUFFERIZE + INDEX
640    let realize_axes_opt = ctx.get_realize_axes(src).cloned();
641    if let Some(ref realize_axes) = realize_axes_opt {
642        let (_, output_ranges) = ctx.get_ranges(src).expect("Realized op must have ranges");
643
644        let closed_ranges: Vec<_> = output_ranges
645            .iter()
646            .enumerate()
647            .filter(|(i, _)| realize_axes.contains(i))
648            .map(|(_, r)| Arc::clone(r))
649            .collect();
650
651        // Tinygrad (indexing.py:67): removable = x.op is not Ops.COPY and s.op not in ALWAYS_CONTIGUOUS
652        let is_copy_consumer = matches!(consumer.op(), Op::Copy { .. });
653        let is_always_contiguous_src = super::indexing::is_always_contiguous(src);
654        let removable = !is_copy_consumer && !is_always_contiguous_src;
655        let addrspace = if output_ranges.len() == realize_axes.len() { AddrSpace::Global } else { AddrSpace::Local };
656        tracing::debug!(
657            src_id = src.id,
658            src_op = src.op().as_ref(),
659            consumer_id = consumer.id,
660            consumer_op = consumer.op().as_ref(),
661            realize_axes = ?realize_axes,
662            output_ranges_len = output_ranges.len(),
663            addrspace = ?addrspace,
664            removable = removable,
665            "BUFFERIZE decision"
666        );
667        // Propagate source device to BUFFERIZE opts (Tinygrad indexing.py:69: device=s.device)
668        let device = src.device_spec();
669        let opts = BufferizeOpts { device, addrspace, removable };
670
671        // Tinygrad (indexing.py:71): tag=s.tag if GLOBAL, else None
672        let buf_tag = if addrspace == AddrSpace::Global { src.tag().clone() } else { None };
673        let bufferized = UOp::bufferize(Arc::clone(src), closed_ranges.clone(), opts);
674        let bufferized = if let Some(t) = buf_tag { bufferized.with_tag(t) } else { bufferized };
675
676        let index_ranges: Vec<_> = input_ranges
677            .iter()
678            .enumerate()
679            .filter(|(i, _)| realize_axes.contains(i))
680            .map(|(_, r)| Arc::clone(r))
681            .collect();
682
683        if !index_ranges.is_empty() {
684            // Create multi-index INDEX; linearization happens in pm_add_buffers_patterns
685            // via linearize_index_on_bufferize (BPM pattern).
686            return UOp::index()
687                .buffer(bufferized)
688                .indices(index_ranges)
689                .call()
690                .expect("Failed to create INDEX after BUFFERIZE");
691        } else {
692            return bufferized;
693        }
694    }
695
696    // Default: no transformation — BPM engine handles movement/compute ops individually
697    Arc::clone(src)
698}
699
700// ============================================================================
701// BUFFERIZE TO STORE CONVERSION
702// ============================================================================
703
704// ============================================================================
705// HELPER FUNCTIONS FOR BUFFERIZE_TO_STORE
706// ============================================================================
707
708/// Apply movement ops chain in reverse order.
709/// Walks from chain root to base using pattern matching.
710/// Uses existing .base() method at ir/src/uop/core.rs:425-438.
711fn apply_movement_ops_chain(result: &Arc<UOp>, chain: &Arc<UOp>) -> Option<Arc<UOp>> {
712    let mut mops = Vec::new();
713    let mut walk = chain.clone();
714
715    // Walk chain collecting movement ops using pattern matching
716    while walk.op().is_movement() {
717        mops.push(walk.clone());
718        // Extract src via pattern matching
719        walk = match walk.op() {
720            Op::Reshape { src, .. }
721            | Op::Permute { src, .. }
722            | Op::Expand { src, .. }
723            | Op::Pad { src, .. }
724            | Op::Shrink { src, .. }
725            | Op::Flip { src, .. } => src.clone(),
726            _ => break,
727        };
728    }
729
730    // Apply in reverse order
731    let mut current = result.clone();
732    for mop in mops.into_iter().rev() {
733        current = apply_single_movement_op(&current, mop.op())?;
734    }
735
736    Some(current)
737}
738
739/// Apply a single movement operation.
740///
741/// Note: This extracts shape from the movement op's stored source (which has the
742/// target shape after the movement) rather than from UOp shape metadata.
743fn apply_single_movement_op(uop: &Arc<UOp>, op: &Op) -> Option<Arc<UOp>> {
744    match op {
745        Op::Reshape { new_shape, .. } => {
746            let shape = extract_shape_from_uop(new_shape)?;
747            uop.try_reshape(&shape).ok()
748        }
749        Op::Permute { axes, .. } => uop.try_permute(axes.clone()).ok(),
750        Op::Expand { new_shape, .. } => {
751            let shape = extract_shape_from_uop(new_shape)?;
752            uop.try_expand(&shape).ok()
753        }
754        Op::Pad { begin_pads, end_pads, .. } => {
755            let begins = extract_shape_from_uop(begin_pads)?;
756            let ends = extract_shape_from_uop(end_pads)?;
757            let padding: Vec<_> = begins.into_iter().zip(ends).collect();
758            uop.try_pad(&padding).ok()
759        }
760        Op::Shrink { begins, ends, .. } => {
761            let begin_shape = extract_shape_from_uop(begins)?;
762            let end_shape = extract_shape_from_uop(ends)?;
763            let ranges: Vec<_> = begin_shape.into_iter().zip(end_shape).collect();
764            uop.try_shrink(&ranges).ok()
765        }
766        Op::Flip { axes, .. } => uop.try_flip(axes.clone()).ok(),
767        _ => None,
768    }
769}
770
771/// Extract shape from a UOp (for movement op parameters).
772/// Handles VECTORIZE, CONST, and VCONST patterns.
773fn extract_shape_from_uop(shape_uop: &Arc<UOp>) -> Option<Shape> {
774    use morok_ir::SInt;
775    match shape_uop.op() {
776        // VECTORIZE with Index-typed elements
777        Op::Vectorize { elements } => Some(elements.iter().cloned().map(SInt::from).collect()),
778        // Single CONST value (for 1D shapes)
779        Op::Const(const_hash) => match const_hash.0 {
780            ConstValue::Int(v) if v >= 0 => Some(smallvec![SInt::from(v as usize)]),
781            ConstValue::UInt(v) => Some(smallvec![SInt::from(v as usize)]),
782            _ => None,
783        },
784        // VConst for multiple concrete dimensions
785        Op::VConst { values } => {
786            let mut dims = smallvec![];
787            for val in values {
788                match val {
789                    ConstValue::Int(v) if *v >= 0 => dims.push(SInt::from(*v as usize)),
790                    ConstValue::UInt(v) => dims.push(SInt::from(*v as usize)),
791                    _ => return None,
792                }
793            }
794            Some(dims)
795        }
796        _ => None,
797    }
798}
799
800/// Create a LOOP range from an OUTER range with the same axis_id.
801fn create_loop_range_from_outer(outer: &Arc<UOp>, size: usize) -> Option<Arc<UOp>> {
802    use morok_ir::AxisType;
803    let Op::Range { axis_id, .. } = outer.op() else {
804        return None;
805    };
806    Some(UOp::range_axis(UOp::index_const(size as i64), *axis_id, AxisType::Loop))
807}
808
809/// Convert ReduceOp to binary operation.
810fn reduce_op_to_binary(op: morok_ir::ReduceOp, lhs: &Arc<UOp>, rhs: &Arc<UOp>) -> Option<Arc<UOp>> {
811    use morok_ir::types::{BinaryOp, ReduceOp};
812    let dtype = lhs.dtype();
813    Some(match op {
814        ReduceOp::Add => UOp::new(Op::Binary(BinaryOp::Add, lhs.clone(), rhs.clone()), dtype),
815        ReduceOp::Mul => UOp::new(Op::Binary(BinaryOp::Mul, lhs.clone(), rhs.clone()), dtype),
816        ReduceOp::Max => UOp::new(Op::Binary(BinaryOp::Max, lhs.clone(), rhs.clone()), dtype),
817        ReduceOp::Min => {
818            // Min uses WHERE(a < b, a, b) pattern
819            let cond = UOp::new(Op::Binary(BinaryOp::Lt, lhs.clone(), rhs.clone()), morok_dtype::DType::Bool);
820            UOp::try_where(cond, lhs.clone(), rhs.clone()).expect("reduce_op_to_binary: try_where failed for Min")
821        }
822    })
823}
824
825/// Calculate buffer size from RANGE operations.
826/// Calculate buffer size from BUFFERIZE ranges.
827/// Matches Tinygrad: `size = prod(x.shape)` where `x.shape = [int(r.vmax+1) for r in src[1:]]`.
828/// Each range contributes `vmax+1` to the product (RANGE UOps have vmax = end-1, so vmax+1 = end).
829/// For flattened BUFFERIZE (single computed expression), vmax+1 gives the total flat size.
830fn calculate_size_from_ranges(ranges: &SmallVec<[Arc<UOp>; 4]>) -> usize {
831    if ranges.is_empty() {
832        return 1;
833    }
834
835    ranges
836        .iter()
837        .map(|r| {
838            // Tinygrad: int(r.vmax+1) — works for both RANGE and computed expressions
839            let vmax = r.vmax();
840            match vmax {
841                ConstValue::Int(v) if *v >= 0 => (*v + 1) as usize,
842                ConstValue::UInt(v) => (*v + 1) as usize,
843                other => panic!(
844                    "Cannot allocate buffer: range vmax resolved to {:?}. \
845                     Buffers require concrete sizes (Tinygrad: 'no symbolic sized buffers')",
846                    other
847                ),
848            }
849        })
850        .product()
851}
852
853/// Sort ranges by (axis_id, axis_type) for correct row-major linearization.
854///
855/// Tinygrad reference (rangeify.py:303):
856///   rngs = sorted(idx.ranges, key=lambda x: x.arg)
857///
858/// Tinygrad's RANGE.arg is (axis_id, axis_type), so sorting uses both.
859/// This ensures that multi-dimensional ranges are linearized in the correct
860/// order regardless of their insertion order in the graph.
861fn sort_ranges_by_axis_id(ranges: &SmallVec<[Arc<UOp>; 4]>) -> SmallVec<[Arc<UOp>; 4]> {
862    let mut sorted: Vec<_> = ranges.iter().cloned().collect();
863    sorted.sort_by_key(|r| {
864        if let Op::Range { axis_id, axis_type, .. } = r.op() {
865            // Sort by (axis_id, axis_type) to match Tinygrad's sorting by x.arg
866            (axis_id.value(), axis_type_ordinal(*axis_type))
867        } else {
868            (usize::MAX, u8::MAX)
869        }
870    });
871    sorted.into()
872}
873
874/// Convert AxisType to ordinal for consistent sorting.
875/// Order matches enum definition order in ir/src/types.rs.
876fn axis_type_ordinal(at: AxisType) -> u8 {
877    match at {
878        AxisType::Outer => 0,
879        AxisType::Global => 1,
880        AxisType::Warp => 2,
881        AxisType::Local => 3,
882        AxisType::Loop => 4,
883        AxisType::GroupReduce => 5,
884        AxisType::Reduce => 6,
885        AxisType::Upcast => 7,
886        AxisType::Unroll => 8,
887        AxisType::Thread => 9,
888        AxisType::Placeholder => 10,
889    }
890}
891
892/// Collect RANGE UOps from BUFFERIZE ranges, traversing flattened expressions.
893///
894/// After `flatten_bufferize`, `ranges[0]` may be a computed expression (Add/Mul of RANGEs)
895/// rather than a direct RANGE UOp. This helper traverses all range entries:
896/// - Direct RANGE UOps are collected immediately
897/// - Non-CONST expressions are traversed via `.ranges()` to find embedded RANGE UOps
898/// - CONST entries (collapsed singleton dims) are skipped
899/// - Deduplicates by UOp id
900fn collect_range_uops(ranges: &SmallVec<[Arc<UOp>; 4]>) -> SmallVec<[Arc<UOp>; 4]> {
901    let mut collected = SmallVec::new();
902    for r in ranges.iter() {
903        if matches!(r.op(), Op::Range { .. }) {
904            collected.push(r.clone());
905        } else if !matches!(r.op(), Op::Const(_)) {
906            for rng in r.ranges().iter() {
907                if !collected.iter().any(|c: &Arc<UOp>| c.id == rng.id) {
908                    collected.push(rng.clone());
909                }
910            }
911        }
912    }
913    collected
914}
915
/// Convert a BUFFERIZE operation into a STORE into a buffer, wrapped in END/AFTER.
///
/// Three cases are handled, in this order:
/// 1. `compute` is an ASSIGN whose target is an INDEX. The ASSIGN value is stored through
///    that INDEX into the target's existing buffer. The ASSIGN's movement ops, if any, are
///    then replayed on the resulting AFTER.
/// 2. `compute` is a REDUCE over exactly one OUTER range (global address space only).
///    - A new buffer is created and zero-initialised with the reduce identity.
///    - `buf[idx] = buf[idx] OP reduce_src` is stored inside END(ranges) and END(outer_range).
/// 3. Otherwise the store target is chosen as follows:
///    - reuse the buffer already mapped for this BUFFERIZE in `ctx`, if any;
///    - else create a BUFFER (global address space);
///    - else create a DEFINE_LOCAL (local address space).
///
///    `compute` is then stored into it.
///
/// On success the BUFFERIZE -> result mapping is recorded via `ctx.map_buffer`.
///
/// # Arguments
///
/// * `bufferize_op` - The BUFFERIZE UOp to convert
/// * `ctx` - Kernel context for tracking buffers and generating IDs
/// * `allow_locals` - If false, a BUFFERIZE in local address space is left untouched and `None`
///   is returned (Tinygrad: pm_add_buffers). If true, a DEFINE_LOCAL is created for local
///   address space and the store is followed by a BARRIER (Tinygrad: pm_add_buffers_local).
///
/// # Returns
///
/// `Some(AFTER)` on success; in case 1 the AFTER may be wrapped in replayed movement ops.
/// Returns `None` in any of these situations:
/// - `bufferize_op` is not a BUFFERIZE;
/// - the ASSIGN target is not an INDEX;
/// - an OUTER reduce is not in global address space;
/// - the buffer is local and `allow_locals` is false;
/// - a helper (movement-op replay, loop-range creation) returns `None`.
///
/// # Panics
///
/// Panics in any of these situations:
/// - the buffer size is not concrete (see `calculate_size_from_ranges`);
/// - an INDEX cannot be built;
/// - on the general path, more than one non-CONST range survives `flatten_bufferize`.
pub fn bufferize_to_store(bufferize_op: &Arc<UOp>, ctx: &mut KernelContext, allow_locals: bool) -> Option<Arc<UOp>> {
    // Only BUFFERIZE nodes are converted; anything else is left unchanged (None).
    let (compute, ranges, opts) = match bufferize_op.op() {
        Op::Bufferize { compute, ranges, opts } => {
            tracing::debug!(
                bufferize_id = bufferize_op.id,
                compute_id = compute.id,
                ranges_len = ranges.len(),
                allow_locals = allow_locals,
                "bufferize_to_store: CONVERTING BUFFERIZE to STORE→AFTER"
            );
            (compute, ranges, opts)
        }
        _ => return None,
    };

    // Calculate size and base dtype upfront (needed for both buffer creation and INDEX dtype).
    // A Ptr-typed BUFFERIZE contributes its pointee type; any other dtype is used as-is.
    let size = calculate_size_from_ranges(ranges);
    let base_dtype = match bufferize_op.dtype() {
        DType::Ptr { base, .. } => (*base).clone(),
        other => other,
    };

    // Calculate sdtype explicitly like Tinygrad (rangeify.py:306):
    //   sdtype = x.dtype.ptr(size=size, addrspace=x.arg.addrspace)
    // This is the pointer type used for STORE targets, ensuring consistent
    // size and addrspace across all INDEX operations in this function.
    let sdtype = base_dtype.clone().ptr(Some(size), opts.addrspace);

    // Get end_ranges for wrapping stores.
    // Tinygrad: `.end(*rngs)` where `rngs = sorted(idx.ranges, ...)`.
    // Used by the ASSIGN and OUTER REDUCE cases; the general path below recomputes
    // the same list and shadows this binding.
    let end_ranges: SmallVec<[Arc<UOp>; 4]> = sort_ranges_by_axis_id(&collect_range_uops(ranges));

    // =========================================================================
    // Case 1: ASSIGN → STORE (reuse existing buffer from ASSIGN target)
    // Tinygrad reference: rangeify.py:307-320
    // =========================================================================
    if let Op::Assign { target, value, movement_ops } = compute.op() {
        // Target must be an INDEX pointing to a buffer; otherwise leave the BUFFERIZE alone.
        let Op::Index { buffer, indices, gate } = target.op() else {
            return None;
        };

        // Create store target with explicit sdtype (Tinygrad: assign_target.replace(dtype=sdtype))
        let store_target = UOp::index()
            .buffer(buffer.clone())
            .indices(indices.to_vec())
            .maybe_gate(gate.clone())
            .dtype(sdtype.clone())
            .call()
            .expect("bufferize_to_store: failed to create INDEX for ASSIGN target");

        // Create STORE and wrap with END (only when there are ranges to close)
        let store = store_target.store_value(value.clone());
        let do_store = if end_ranges.is_empty() { store } else { store.end(end_ranges.clone()) };

        // Apply movement ops in reverse order on top of AFTER(buffer, store)
        let mut result = buffer.after(smallvec![do_store]);
        if let Some(mops_chain) = movement_ops {
            result = apply_movement_ops_chain(&result, mops_chain)?;
        }

        ctx.map_buffer(bufferize_op.clone(), result.clone());
        return Some(result);
    }

    // =========================================================================
    // Case 2: OUTER REDUCE with zero initialization
    // Tinygrad reference: rangeify.py:323-332
    // =========================================================================
    if let Op::Reduce { src: reduce_src, ranges: reduce_ranges, reduce_op } = compute.op() {
        // OUTER reduce case: exactly ONE range that is OUTER type
        // Tinygrad: len(x.src[0].src) == 2 means src + 1 range
        if reduce_ranges.len() == 1
            && let Op::Range { axis_type, .. } = reduce_ranges[0].op()
            && *axis_type == AxisType::Outer
        {
            // Must be global address space
            if opts.addrspace != AddrSpace::Global {
                return None;
            }

            let outer_range = reduce_ranges[0].clone();
            // Falls back to the CPU device when the BUFFERIZE carries no device.
            let device = opts.device.clone().unwrap_or(morok_ir::DeviceSpec::Cpu);

            // Create buffer
            let buf = UOp::new_buffer(device, size, base_dtype.clone());

            // Create zero-init range (same axis_id but AxisType::Loop), spanning the whole buffer
            let zero_range = create_loop_range_from_outer(&outer_range, size)?;

            // Get identity value for reduce op
            use crate::symbolic::dce::reduce_identity;
            let identity = reduce_identity(*reduce_op, base_dtype.clone());

            // Zero-initialize: buf[zero_range] = identity
            let zero_idx = UOp::index()
                .buffer(buf.clone())
                .indices(vec![zero_range.clone()])
                .dtype(sdtype.clone())
                .call()
                .expect("bufferize_to_store: failed to create INDEX for OUTER REDUCE zero-init");
            let zero_store = zero_idx.store_value(identity).end(smallvec![zero_range.clone()]);
            // AFTER orders every later use of `buf_zeroed` behind the zero-init store.
            let buf_zeroed = buf.after(smallvec![zero_store]);

            // Use BUFFERIZE's index directly (already flattened by flatten_bufferize).
            // Matches Tinygrad: `bufi = buf.index(idx, dtype=sdtype)` where idx = x.src[1]
            // NOTE(review): this is a debug_assert while the general path below uses assert!;
            // release builds skip this check here.
            debug_assert!(
                ranges.len() <= 1 || ranges.iter().all(|r| matches!(r.op(), Op::Const(_))),
                "bufferize_to_store: unexpected multi-range in OUTER REDUCE after flatten_bufferize"
            );
            // NOTE(review): `end_ranges` is already sorted, so the re-sort in the second arm is redundant.
            let idx = if ranges.len() == 1 && !matches!(ranges[0].op(), Op::Const(_)) {
                ranges[0].clone()
            } else if !end_ranges.is_empty() {
                sort_ranges_by_axis_id(&end_ranges)[0].clone()
            } else {
                UOp::index_const(0)
            };

            // Collect RANGE UOps from the index expression for END wrapping.
            // NOTE(review): computed from the same `ranges` as `end_ranges` above, so it is equal to it.
            let sorted_end_ranges = sort_ranges_by_axis_id(&collect_range_uops(ranges));

            // Accumulation: buf[idx] = buf[idx] OP reduce_src (Tinygrad: bufi = buf.index(idx, dtype=sdtype))
            let buf_idx = UOp::index()
                .buffer(buf_zeroed.clone())
                .indices(vec![idx])
                .dtype(sdtype.clone())
                .call()
                .expect("bufferize_to_store: failed to create INDEX for OUTER REDUCE accumulation");
            let loaded = UOp::load().buffer(buf_zeroed.clone()).index(buf_idx.clone()).call();
            let accumulated = reduce_op_to_binary(*reduce_op, &loaded, reduce_src)?;

            // Wrap store with both collected end_ranges AND outer_range (outer_range outermost)
            let do_store = buf_idx.store_value(accumulated).end(sorted_end_ranges).end(smallvec![outer_range]);

            let result = buf_zeroed.after(smallvec![do_store]);
            ctx.map_buffer(bufferize_op.clone(), result.clone());
            return Some(result);
        }
    }

    // Determine effective address space based on allow_locals parameter
    // Tinygrad has two matchers:
    // - pm_add_buffers (allow_locals=False): skips local buffers entirely (returns None)
    // - pm_add_buffers_local (allow_locals=True): creates DEFINE_LOCAL for local
    //
    // When allow_locals=false and the buffer is LOCAL, we return None to leave the
    // BUFFERIZE as-is. This matches Tinygrad's behavior where local buffers are only
    // converted during codegen (pm_add_buffers_local), NOT during kernel splitting.
    if !allow_locals && opts.addrspace == AddrSpace::Local {
        return None;
    }
    let effective_addrspace = opts.addrspace;

    // Buffer selection: reuse a mapping recorded earlier in `ctx`, else allocate by addrspace.
    let buffer = if let Some(existing_buffer) = ctx.get_buffer(bufferize_op) {
        existing_buffer.clone()
    } else if effective_addrspace == AddrSpace::Global {
        // Create BUFFER node (like Tinygrad's UOp.new_buffer)
        // The BUFFER → PARAM conversion happens later in split_store
        let device = opts.device.clone().unwrap_or(morok_ir::DeviceSpec::Cpu);
        UOp::new_buffer(device, size, base_dtype.clone())
    } else {
        // For local address space (only when allow_locals=true), create DEFINE_LOCAL directly (like Tinygrad)
        // NOTE(review): `local_ptr_dtype` is computed exactly like `sdtype` above.
        let local_ptr_dtype = base_dtype.clone().ptr(Some(size), opts.addrspace);
        let local_id = ctx.next_local();
        UOp::define_local(local_id, local_ptr_dtype)
    };

    // The STORE-side INDEX is built with an explicit pointer dtype (`sdtype`), so STORE sees a
    // pointer-typed target (Tinygrad-aligned: buf.index(idx, dtype=sdtype)).

    // Collect active RANGE UOps from the ranges.
    // Tinygrad: `rngs = sorted(idx.ranges, ...)` — traverses expression tree for RANGE UOps.
    let active_ranges: SmallVec<[Arc<UOp>; 4]> = collect_range_uops(ranges);

    // Sort active ranges by axis_id for correct row-major linearization (Tinygrad: rangeify.py:303)
    let sorted_ranges = sort_ranges_by_axis_id(&active_ranges);

    // Broadcast buffer for STORE-side INDEX only (Tinygrad: buf.broadcast(count).index(idx))
    // The AFTER return uses the unbroadcast buffer so consumers can broadcast it properly.
    let vcount = compute.dtype().vcount();
    let store_buffer = if vcount > 1 { buffer.broadcast(vcount) } else { buffer.clone() };

    let store_target = if !sorted_ranges.is_empty() {
        // After flatten_bufferize, ranges[0] may be the already-linearized flat index.
        // Use it directly. For non-flattened single-range, the RANGE is used directly.
        // Matches Tinygrad: buf.index(idx, dtype=sdtype)
        assert!(
            ranges.len() <= 1 || ranges.iter().all(|r| matches!(r.op(), Op::Const(_))),
            "bufferize_to_store: unexpected multi-range in general path after flatten_bufferize"
        );
        let idx = if ranges.len() == 1 && !matches!(ranges[0].op(), Op::Const(_)) {
            // Single range element (possibly flattened expression or RANGE)
            ranges[0].clone()
        } else {
            // Defensive fallback. Given the assert above this looks unreachable:
            // - a lone CONST entry or all-CONST ranges yield no RANGEs, so `sorted_ranges`
            //   would be empty and this outer branch would not be taken;
            // - multiple non-CONST entries trip the assert.
            sorted_ranges[0].clone()
        };
        UOp::index()
            .buffer(store_buffer)
            .indices(vec![idx])
            .dtype(sdtype.clone())
            .call()
            .expect("Failed to create INDEX for BUFFERIZE-to-STORE conversion")
    } else {
        // Scalar store: create INDEX with buffer + index 0 and explicit sdtype
        UOp::index()
            .buffer(store_buffer)
            .indices(vec![UOp::index_const(0)])
            .dtype(sdtype.clone())
            .call()
            .expect("Failed to create INDEX for scalar STORE")
    };

    // Create STORE and wrap with END if there are output ranges.
    // This matches Tinygrad's architecture: .store().end(*rngs)
    //
    // The END wrapper is critical because:
    // 1. split_store looks for END { computation: STORE, ranges } pattern
    // 2. END.ranges define the iteration space for the OUTPUT (not internal computations)
    // 3. For scalar stores (e.g., REDUCE results), no END wrapping (ranges is empty)
    // 4. REDUCE's loop is handled by pm_reduce which creates its own END internally
    // NOTE: STORE takes (index, value) - buffer is accessed via index.buffer
    let store = store_target.store_value(compute.clone());

    // Determine END ranges: use only actual RANGE ops from BUFFERIZE (Tinygrad-aligned).
    //
    // Tinygrad's `rngs = sorted(idx.ranges, ...)` naturally excludes CONST(0) entries
    // because `.ranges` only collects RANGE UOps. END should only wrap with actual
    // iteration ranges, not collapsed singleton dimensions.
    let end_ranges: SmallVec<[Arc<UOp>; 4]> = sorted_ranges.clone();

    let mut do_store = if !end_ranges.is_empty() { store.end(end_ranges) } else { store };

    // Local buffers are followed by a BARRIER so later reads observe the stores.
    if opts.addrspace == AddrSpace::Local {
        do_store = do_store.barrier(SmallVec::new());
    }

    let result = buffer.after(SmallVec::from_elem(do_store, 1));
    ctx.map_buffer(bufferize_op.clone(), result.clone());

    Some(result)
}
1168
1169// ============================================================================
1170// REDUCTION SIMPLIFICATIONS
1171// ============================================================================
1172
1173/// Partition ranges into parented and unparented.
1174#[allow(clippy::mutable_key_type)]
1175pub(crate) fn partition_reduce_ranges(
1176    ranges: &SmallVec<[Arc<UOp>; 4]>,
1177    src_ranges: &HashSet<UOpKey>,
1178) -> (SmallVec<[Arc<UOp>; 4]>, Vec<Arc<UOp>>) {
1179    let mut parented = SmallVec::new();
1180    let mut unparented = Vec::new();
1181
1182    for range in ranges {
1183        let key = UOpKey(Arc::clone(range));
1184        if src_ranges.contains(&key) {
1185            parented.push(Arc::clone(range));
1186        } else {
1187            unparented.push(Arc::clone(range));
1188        }
1189    }
1190
1191    (parented, unparented)
1192}
1193
1194pub(crate) fn get_range_size(range: &Arc<UOp>) -> Option<Arc<UOp>> {
1195    if let Op::Range { end, .. } = range.op() { Some(Arc::clone(end)) } else { None }
1196}
1197
/// Try to eliminate the reduce `ranges` from `src` by algebraic simplification, following
/// Tinygrad's `reduce_collapse` (simplify.py:121-134).
///
/// The matcher is a parameter, mirroring Tinygrad's
/// `def reduce_collapse(red, u, pm=pm_reduce_collapse)`.
///
/// For each reduce range, in order:
/// 1. Gated toposort to find the nodes "in scope" of the range, i.e. nodes whose
///    `in_scope_ranges()` contain it.
/// 2. Replace external inputs (children of in-scope nodes that are not themselves in scope)
///    with synthetic DEFINE_VARs carrying the input's vmin/vmax and dtype.
/// 3. Wrap the substituted body in a synthetic REDUCE(ADD) over the range.
/// 4. Run `pm` (e.g. bound-from-below/above, distributive patterns).
/// 5. If no RANGE remains anywhere in the result, reverse-substitute the DEFINE_VARs back to
///    the original inputs and continue with the next range.
///
/// Returns `Some(collapsed)` only if every range was eliminated. Returns `None` in any of
/// these situations:
/// - `ranges` is empty;
/// - a REDUCE or STORE lies in scope of a range;
/// - any range survives the rewrite.
#[allow(clippy::mutable_key_type)]
fn reduce_collapse_with(src: &Arc<UOp>, ranges: &[Arc<UOp>], pm: &crate::TypedPatternMatcher<()>) -> Option<Arc<UOp>> {
    use morok_ir::ReduceOp;

    if ranges.is_empty() {
        return None;
    }

    // `u` is the progressively collapsed body; each iteration removes one range from it.
    let mut u = Arc::clone(src);

    for range in ranges {
        // 1. Gated toposort: find nodes "in scope" of this range
        let range_key = UOpKey(range.clone());
        let in_scope: HashSet<UOpKey> =
            u.toposort_filtered(|node| node.in_scope_ranges().contains(&range_key)).into_iter().map(UOpKey).collect();

        // Bail if nested REDUCE or STORE in scope (can't collapse through these)
        if in_scope.iter().any(|k| matches!(k.0.op(), Op::Reduce { .. } | Op::Store { .. })) {
            return None;
        }

        // 2. Identify external inputs and substitute with DEFINE_VAR
        // (Tinygrad excludes: CONST, VCONST, PARAM, DEFINE_LOCAL, DEFINE_VAR)
        // Only device-less PARAMs are skipped here; a PARAM with a device is replaced like any input.
        // NOTE(review): `in_scope` is a HashSet, so the order in which inputs are numbered
        // `in0`, `in1`, ... is not deterministic across runs. Confirm nothing depends on the names.
        let mut replaces: HashMap<UOpKey, Arc<UOp>> = HashMap::new();
        for node in &in_scope {
            node.0.op().map_child(|child| {
                let key = UOpKey(child.clone());
                if in_scope.contains(&key) || replaces.contains_key(&key) {
                    return;
                }
                if matches!(
                    child.op(),
                    Op::Const(_)
                        | Op::VConst { .. }
                        | Op::DefineVar { .. }
                        | Op::Param { device: None, .. }
                        | Op::DefineLocal { .. }
                ) {
                    return;
                }
                // Bounds are narrowed to i64 with `as`: floats truncate toward zero (saturating,
                // NaN -> 0), bools become 0/1.
                let vmin = match child.vmin() {
                    ConstValue::Int(i) => *i,
                    ConstValue::UInt(u) => *u as i64,
                    ConstValue::Float(f) => *f as i64,
                    ConstValue::Bool(b) => *b as i64,
                };
                let vmax = match child.vmax() {
                    ConstValue::Int(i) => *i,
                    ConstValue::UInt(u) => *u as i64,
                    ConstValue::Float(f) => *f as i64,
                    ConstValue::Bool(b) => *b as i64,
                };
                let var = UOp::define_var(format!("in{}", replaces.len()), vmin, vmax).with_dtype(child.dtype());
                replaces.insert(key, var);
            });
        }

        // 3. Build synthetic REDUCE: substituted_body.reduce([range], ADD)
        let substituted = u.substitute(&replaces);
        let synthetic_reduce = substituted.reduce(smallvec![range.clone()], ReduceOp::Add);

        // 4. Apply algebraic patterns to try eliminating the range
        let result = crate::rewrite::graph_rewrite(pm, synthetic_reduce, &mut ());

        // 5. Check range eliminated (use plain toposort, NOT in_scope_ranges,
        //    since REDUCE "ends" ranges and would give a false positive)
        let has_range = result.toposort().iter().any(|x| matches!(x.op(), Op::Range { .. }));
        if has_range {
            return None;
        }

        // 6. Reverse substitute: DEFINE_VAR → original external inputs
        let reverse: HashMap<UOpKey, Arc<UOp>> = replaces.into_iter().map(|(k, v)| (UOpKey(v), k.0)).collect();
        u = result.substitute(&reverse);
    }

    Some(u)
}
1289
1290/// Collapse REDUCE using `pm_reduce_collapse` patterns.
1291///
1292/// Tinygrad: `reduce_collapse(red, u)` (uses default `pm=pm_reduce_collapse`).
1293pub fn reduce_collapse(src: &Arc<UOp>, ranges: &[Arc<UOp>]) -> Option<Arc<UOp>> {
1294    reduce_collapse_with(src, ranges, super::patterns::build_reduce_collapse_matcher())
1295}
1296
1297/// Collapse REDUCE using extended `pm_reduce_load_collapse` patterns.
1298///
1299/// Tinygrad: `reduce_load_collapse(red, u)` — calls `reduce_collapse` with
1300/// `pm=pm_reduce_load_collapse` which includes `.or_casted()` variants,
1301/// NE lifting, and the full `pm_load_collapse` non-REDUCE patterns.
1302pub fn reduce_load_collapse(src: &Arc<UOp>, ranges: &[Arc<UOp>]) -> Option<Arc<UOp>> {
1303    reduce_collapse_with(src, ranges, super::patterns::build_reduce_load_collapse_matcher())
1304}
1305
1306pub(crate) fn cast_to_dtype(value: &Arc<UOp>, target_dtype: &morok_dtype::DType) -> Option<Arc<UOp>> {
1307    use morok_dtype::DType;
1308
1309    let scalar_type = match target_dtype {
1310        DType::Scalar(s) => DType::Scalar(*s),
1311        DType::Vector { scalar, .. } => DType::Scalar(*scalar),
1312        _ => return None,
1313    };
1314
1315    let casted = value.cast(scalar_type);
1316
1317    if target_dtype.is_vector() {
1318        let count = target_dtype.count();
1319        let elements: SmallVec<[Arc<UOp>; 4]> = (0..count).map(|_| casted.clone()).collect();
1320        Some(UOp::vectorize(elements))
1321    } else {
1322        Some(casted)
1323    }
1324}
1325
1326// ============================================================================
1327// RANGE SIMPLIFICATION
1328// ============================================================================
1329
1330/// Simplify ranges by merging adjacent ranges to reduce divmod operations.
1331///
1332/// Based on Tinygrad's `pm_simplify_ranges` (simplify.py:20-37).
1333///
1334/// This optimization merges adjacent ranges when the merge reduces the number of
1335/// IDIV and MOD operations in the computation graph. The merged range is then
1336/// decomposed back using divmod to preserve correctness.
1337///
1338/// Key validation from Tinygrad:
1339/// - Both ranges must appear in the same REDUCE operations (consistent scoping)
1340/// - Both ranges must have the same axis type
1341///
1342/// Example:
1343/// - Original: Two ranges R1(16) and R2(8)
1344/// - Merge: Create R_merged(128), decompose as R1 = merged // 8 and R2 = merged % 8
1345/// - Accept: Only if this reduces or maintains the divmod count
1346pub fn simplify_merge_adjacent(u: &Arc<UOp>) -> Option<Arc<UOp>> {
1347    use crate::passes::linearize_index::count_divmod;
1348
1349    // Get ended ranges for this operation
1350    let ended_ranges = match u.op() {
1351        Op::End { computation: _, ranges } => ranges.clone(),
1352        Op::Reduce { ranges, .. } => ranges.clone(),
1353        _ => return None,
1354    };
1355
1356    if ended_ranges.len() < 2 {
1357        return None;
1358    }
1359
1360    // Collect all REDUCE operations in the backward slice (Tinygrad simplify.py:21)
1361    let reduce_ranges: Vec<SmallVec<[Arc<UOp>; 4]>> = u
1362        .toposort()
1363        .iter()
1364        .filter_map(|dep| match dep.op() {
1365            Op::Reduce { ranges, .. } => Some(ranges.clone()),
1366            _ => None,
1367        })
1368        .collect();
1369
1370    // Cumulative merging (Tinygrad simplify.py:37: `u = nidx` inside loop)
1371    // Try all pairs and accumulate successful merges into `current`.
1372    let mut current = Arc::clone(u);
1373    let mut changed = false;
1374
1375    // Re-extract ranges from current for each iteration
1376    let pairs: Vec<(usize, usize)> = if matches!(u.op(), Op::End { .. }) {
1377        (0..ended_ranges.len() - 1).map(|i| (i, i + 1)).collect()
1378    } else {
1379        let mut perms = Vec::new();
1380        for i in 0..ended_ranges.len() {
1381            for j in 0..ended_ranges.len() {
1382                if i != j {
1383                    perms.push((i, j));
1384                }
1385            }
1386        }
1387        perms
1388    };
1389
1390    for (i0, i1) in pairs {
1391        let r0 = &ended_ranges[i0];
1392        let r1 = &ended_ranges[i1];
1393
1394        let (r0_axis_type, r0_end) = match r0.op() {
1395            Op::Range { end, axis_type, .. } => (axis_type, end),
1396            _ => continue,
1397        };
1398        let (r1_axis_type, r1_end) = match r1.op() {
1399            Op::Range { end, axis_type, .. } => (axis_type, end),
1400            _ => continue,
1401        };
1402
1403        if r0_axis_type != r1_axis_type {
1404            continue;
1405        }
1406
1407        // Check same REDUCE scope (Tinygrad simplify.py:25-27)
1408        let valid_reduce_scope = reduce_ranges.iter().all(|rngs| {
1409            let r0_in = rngs.iter().any(|rng| Arc::ptr_eq(rng, r0));
1410            let r1_in = rngs.iter().any(|rng| Arc::ptr_eq(rng, r1));
1411            r0_in == r1_in
1412        });
1413        if !valid_reduce_scope {
1414            continue;
1415        }
1416
1417        if let Some(v) = const_uop_to_i64(r0_end)
1418            && v <= 0
1419        {
1420            continue;
1421        }
1422        if let Some(v) = const_uop_to_i64(r1_end)
1423            && v <= 0
1424        {
1425            continue;
1426        }
1427        if let (Some(s0), Some(s1)) = (const_uop_to_i64(r0_end), const_uop_to_i64(r1_end))
1428            && s0.checked_mul(s1).is_none()
1429        {
1430            continue;
1431        }
1432
1433        let merged_size_uop = r0_end.mul(r1_end);
1434        let merged_range = r0.with_sources(vec![merged_size_uop]);
1435
1436        let new_r0 = merged_range.idiv(r1_end);
1437        let new_r1 = merged_range.mod_(r1_end);
1438
1439        #[allow(clippy::mutable_key_type)]
1440        let mut subs: HashMap<UOpKey, Arc<UOp>> = HashMap::new();
1441        subs.insert(UOpKey(r0.clone()), new_r0);
1442        subs.insert(UOpKey(r1.clone()), new_r1);
1443
1444        // Apply substitution and simplify (Tinygrad simplify.py:30-31)
1445        let rewritten = current.substitute(&subs);
1446        static MERGE_SYM: std::sync::LazyLock<crate::TypedPatternMatcher> =
1447            std::sync::LazyLock::new(|| crate::symbolic::symbolic().clone() + pm_flatten_range().clone());
1448        let simplified = crate::rewrite::graph_rewrite(&*MERGE_SYM, rewritten, &mut ());
1449
1450        // Accept if divmod count is reduced or equal (Tinygrad simplify.py:34-36)
1451        let original_divmod = count_divmod(&current);
1452        let new_divmod = count_divmod(&simplified);
1453
1454        if new_divmod <= original_divmod {
1455            current = simplified;
1456            changed = true;
1457        }
1458    }
1459
1460    if changed { Some(current) } else { None }
1461}
1462
/// Pattern matcher for range simplification.
///
/// Tries to merge pairs of ranges ended by END and REDUCE ops into a single range, to reduce
/// divmod operations. Both rules delegate to [`simplify_merge_adjacent`], which keeps a merge
/// only when the div/mod count does not increase. END tries adjacent pairs only; REDUCE tries
/// every ordered pair.
///
/// Returned as a `'static` reference produced by `cached_patterns!`.
pub fn pm_simplify_ranges() -> &'static crate::TypedPatternMatcher {
    crate::cached_patterns! {
        // Match END ops with ranges
        // (`!ranges.is_empty()` is only a cheap pre-filter: simplify_merge_adjacent itself
        // returns None for fewer than two ranges.)
        u @ End { computation: _, ranges } if !ranges.is_empty() => |u| simplify_merge_adjacent(u),
        // Match REDUCE ops with ranges
        u @ Reduce { src: _, ranges, reduce_op: _ } if !ranges.is_empty() => |u| simplify_merge_adjacent(u),
    }
}
1474
1475// ============================================================================
1476// RANGE FLATTENING
1477// ============================================================================
1478
1479/// Flatten nested RANGE operations into canonical form.
1480pub fn flatten_range_impl(r: &Arc<UOp>) -> Option<Arc<UOp>> {
1481    let off = match r.op() {
1482        Op::Reduce { .. } => 1,
1483        Op::Store { .. } => 2, // (index, value, ranges...) - ranges start at index 2
1484        Op::End { .. } => 1,
1485        _ => return None,
1486    };
1487
1488    let original_sources = r.op().sources();
1489    let original_ranges: Vec<&Arc<UOp>> = original_sources.iter().skip(off).collect();
1490    let mut all_range_sources: Vec<Arc<UOp>> = original_ranges.iter().map(|r| (*r).clone()).collect();
1491
1492    let innermost_computation = if matches!(r.op(), Op::End { .. }) {
1493        let mut computation = Arc::clone(&original_sources[0]);
1494
1495        while matches!(computation.op(), Op::End { .. }) {
1496            all_range_sources.extend(computation.op().sources().iter().skip(1).cloned());
1497            computation = Arc::clone(&computation.op().sources()[0]);
1498        }
1499
1500        Some(computation)
1501    } else {
1502        None
1503    };
1504
1505    if all_range_sources.is_empty() {
1506        return None;
1507    }
1508
1509    let sink = UOp::sink(all_range_sources);
1510    let new_ranges: Vec<Arc<UOp>> =
1511        sink.toposort().into_iter().filter(|uop| matches!(uop.op(), Op::Range { .. })).collect();
1512
1513    if new_ranges.is_empty() {
1514        return None;
1515    }
1516
1517    // Check if anything actually changed
1518    if new_ranges.len() == original_ranges.len()
1519        && innermost_computation.as_ref().is_none_or(|c| Arc::ptr_eq(c, &original_sources[0]))
1520        && new_ranges.iter().zip(original_ranges.iter()).all(|(a, b)| Arc::ptr_eq(a, *b))
1521    {
1522        return None; // No change, avoid infinite loop
1523    }
1524
1525    let mut new_sources: Vec<Arc<UOp>> =
1526        if let Some(inner_comp) = innermost_computation { vec![inner_comp] } else { original_sources[..off].to_vec() };
1527    new_sources.extend(new_ranges);
1528
1529    Some(r.with_sources(new_sources))
1530}
1531
1532/// Apply range flattening to a computation graph.
1533#[allow(clippy::mutable_key_type)]
1534pub fn flatten_ranges(root: &Arc<UOp>) -> Arc<UOp> {
1535    let mut replacements: HashMap<UOpKey, Arc<UOp>> = HashMap::new();
1536
1537    for node in root.toposort() {
1538        if let Some(flattened) = flatten_range_impl(&node) {
1539            replacements.insert(UOpKey(node.clone()), flattened);
1540        }
1541    }
1542
1543    root.substitute(&replacements)
1544}
1545
1546// ============================================================================
1547// CYCLE DETECTION
1548// ============================================================================
1549
/// How a buffer is accessed, as recorded by [`find_bufs`] for cycle detection.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum OpAccessType {
    /// The buffer is read by a LOAD.
    Load,
    /// The buffer is the target of a store.
    Store,
}
1556
1557/// Unwrap buffer-like ops to get the underlying buffer.
1558pub fn as_buf(uop: &Arc<UOp>) -> Arc<UOp> {
1559    match uop.op() {
1560        Op::MSelect { buffer, .. } => buffer.clone(),
1561        Op::MStack { buffers } if !buffers.is_empty() => buffers[0].clone(),
1562        Op::After { passthrough, .. } => passthrough.clone(),
1563        _ => uop.clone(),
1564    }
1565}
1566
1567/// Detect conflicting buffer accesses. Panics if same buffer has both LOAD and STORE.
1568#[allow(clippy::mutable_key_type)]
1569pub fn find_bufs(store: &Arc<UOp>) -> HashMap<UOpKey, OpAccessType> {
1570    let mut ret: HashMap<UOpKey, OpAccessType> = HashMap::new();
1571
1572    let nodes = store.toposort_filtered(|uop| !matches!(uop.op(), Op::After { .. }));
1573
1574    for node in nodes {
1575        if let Op::Load { buffer, .. } = node.op() {
1576            let buf = as_buf(buffer);
1577            let buf_key = UOpKey(buf.clone());
1578
1579            if let Some(&existing_access) = ret.get(&buf_key)
1580                && existing_access != OpAccessType::Load
1581            {
1582                panic!(
1583                    "buffer accessed with conflicting ops: {:?} (existing: {:?}, new: {:?})",
1584                    buf,
1585                    existing_access,
1586                    OpAccessType::Load
1587                );
1588            }
1589
1590            ret.insert(buf_key, OpAccessType::Load);
1591        }
1592
1593        if let Some(buffer) = node.store_buffer() {
1594            let buf = as_buf(buffer);
1595            let buf_key = UOpKey(buf.clone());
1596
1597            if let Some(&existing_access) = ret.get(&buf_key)
1598                && existing_access != OpAccessType::Store
1599            {
1600                panic!(
1601                    "buffer accessed with conflicting ops: {:?} (existing: {:?}, new: {:?})",
1602                    buf,
1603                    existing_access,
1604                    OpAccessType::Store
1605                );
1606            }
1607
1608            ret.insert(buf_key, OpAccessType::Store);
1609        }
1610    }
1611
1612    ret
1613}
1614
1615// ============================================================================
1616// PM_ADD_BUFFERS PATTERNS
1617// ============================================================================
1618
1619/// Convert DISK BUFFERIZE(BITCAST|CONTIGUOUS) → BUFFER_VIEW (Tinygrad rangeify.py:285-304).
1620/// For DISK devices, instead of creating a compute kernel, creates a zero-copy typed view
1621/// with byte offset into the memory-mapped file.
1622fn late_buffer_view(compute: &Arc<UOp>, bufferize: &Arc<UOp>) -> Option<Arc<UOp>> {
1623    use morok_ir::uop::cached_property::CachedProperty;
1624    use morok_ir::uop::properties::VminVmaxProperty;
1625
1626    let Op::Bufferize { opts, ranges, .. } = bufferize.op() else { return None };
1627
1628    // Only for DISK device
1629    if !matches!(&opts.device, Some(d) if d.is_disk()) {
1630        return None;
1631    }
1632
1633    // Compute size from ranges (product of range ends)
1634    let size: usize = ranges
1635        .iter()
1636        .map(|r| {
1637            if let Op::Range { end, .. } = r.op()
1638                && let (_, morok_ir::ConstValue::Int(v)) = VminVmaxProperty::get(end)
1639            {
1640                return *v as usize;
1641            }
1642            if let Op::Const(_) = r.op() {
1643                return 1; // const 0 index contributes dim of 1
1644            }
1645            1
1646        })
1647        .product();
1648
1649    // Walk up from compute to find the INDEX node (Tinygrad rangeify.py:291-295)
1650    // In Tinygrad, `t` is the BITCAST/CONTIGUOUS itself. We need to look INTO its children
1651    // for an INDEX, not walk UP past it. The BITCAST's source (after rangeify) should be
1652    // an INDEX or contain one.
1653    let mut x = compute.clone();
1654    loop {
1655        // Check if any SOURCE of x is an INDEX
1656        if x.op().sources().iter().any(|s| matches!(s.op(), Op::Index { .. })) {
1657            break;
1658        }
1659        // For BITCAST/CONTIGUOUS (the starting node), look into their source
1660        if matches!(x.op(), Op::BitCast { .. } | Op::Contiguous { .. }) {
1661            x = x.op().sources().first()?.clone();
1662            continue;
1663        }
1664        // Don't cross other elementwise ops
1665        if matches!(x.op(), Op::Unary(..) | Op::Binary(..) | Op::Ternary(..) | Op::Cast { .. }) {
1666            return None;
1667        }
1668        x = x.op().sources().first()?.clone();
1669    }
1670    let index = x.op().sources().iter().find(|s| matches!(s.op(), Op::Index { .. }))?.clone();
1671
1672    // Compute byte offset (Tinygrad rangeify.py:297-298)
1673    let offset: usize = if let Op::Index { indices, .. } = index.op() {
1674        if indices.is_empty() {
1675            // Scalar: offset from first index's constant arg (Tinygrad: x.src[1].arg)
1676            0
1677        } else {
1678            // Shaped: sum of index vmin values
1679            let mut total: i64 = 0;
1680            for idx in indices.iter() {
1681                let (vmin, _) = VminVmaxProperty::get(idx);
1682                if let morok_ir::ConstValue::Int(v) = vmin {
1683                    total += v;
1684                }
1685            }
1686            total.max(0) as usize
1687        }
1688    } else {
1689        0
1690    };
1691
1692    // Get base buffer (the DISK BUFFER UOp)
1693    let base = index.base();
1694
1695    // Create BUFFER_VIEW with compute's dtype
1696    let buffer_view = UOp::new(Op::BufferView { buffer: base, size, offset }, compute.dtype());
1697
1698    // Replace BUFFERIZE's first source with the BUFFER_VIEW, keep the range source
1699    let new_sources: Vec<Arc<UOp>> = std::iter::once(buffer_view).chain(ranges.iter().cloned()).collect();
1700    Some(UOp::bufferize(new_sources[0].clone(), new_sources[1..].to_vec(), opts.clone()))
1701}
1702
/// Create pattern matcher for adding buffers (BUFFERIZE → STORE conversion).
///
/// Based on Tinygrad's pm_add_buffers (rangeify.py:358-367) with `allow_locals=False`.
/// Uses a shared KernelContext (like Tinygrad's `ctx=itertools.count(lunique_start)`)
/// to ensure unique buffer IDs across all pattern matches.
///
/// Rules, in declaration order:
/// 1. BUFFERIZE with more than one range → `flatten_bufferize`.
/// 2. INDEX of a movement op → `transform_movement_through_index` (pm_mops rule 1).
/// 3. AFTER over a movement op → `push_movement_through_after` (pm_mops rule 2).
/// 4. END over a movement op → END over the movement op's first source (pm_mops rule 3).
/// 5. BUFFERIZE of BITCAST/CONTIGUOUS → `late_buffer_view`, which returns `None` unless the
///    BUFFERIZE targets a DISK device.
/// 6. Any BUFFERIZE → `bufferize_to_store(buf, ctx, false)`, i.e. local buffers are treated
///    as global (no DEFINE_LOCAL; see `pm_add_buffers_local_patterns` for that variant).
///
/// NOTE(review): this assumes `patterns!` tries rules in the order written and falls through
/// when a rule returns `None`, so the unguarded rule 6 acts as the fallback — confirm.
pub fn pm_add_buffers_patterns() -> crate::TypedPatternMatcher<super::kernel::KernelContext> {
    crate::patterns! {
        @context super::kernel::KernelContext;
        // Flatten multi-range BUFFERIZE to 1D (Tinygrad: flatten_bufferize, rangeify.py:381-389)
        // The pattern binds only `compute`, so the guard re-inspects `buf.op()` for the range count.
        buf @ Bufferize { compute: _ } if matches!(buf.op(), Op::Bufferize { ranges, .. } if ranges.len() > 1)
            => |buf, _ctx| { flatten_bufferize(buf) },
        // pm_mops rule 1: push movement ops through INDEX (Tinygrad rangeify.py:25-26)
        Index { buffer: mop, indices, gate } if mop.op().is_movement()
            => |mop, indices, gate, _ctx| {
                super::patterns::transform_movement_through_index(mop, indices, gate)
            },
        // pm_mops rule 2: push movement ops through AFTER (Tinygrad rangeify.py:28-29)
        // AFTER(MOVEMENT(x, ...), deps) → MOVEMENT(AFTER(x, deps), ...)
        After { passthrough: mop, deps } if mop.op().is_movement()
            => |mop, deps, _ctx| {
                push_movement_through_after(mop, deps)
            },
        // pm_mops rule 3: strip movement ops from END (Tinygrad rangeify.py:30)
        // END(MOVEMENT(x, ...), ranges) → END(x, ranges)
        // `x` is the movement op's first source; the END's ranges are kept unchanged.
        End { computation: mop, ranges } if mop.op().is_movement()
            => |mop, ranges, _ctx| {
                let src = &mop.op().sources()[0];
                Some(src.end(ranges.clone()))
            },
        // to_bufferview: DISK BUFFERIZE(BITCAST|CONTIGUOUS) → BUFFER_VIEW (Tinygrad rangeify.py:302-304)
        // late_buffer_view returns None for non-DISK devices.
        buf @ Bufferize { compute }
            if matches!(compute.op(), Op::BitCast { .. } | Op::Contiguous { .. })
            => |buf, compute, _ctx| late_buffer_view(compute, buf),
        // BUFFERIZE → STORE conversion (allow_locals=false: treat local as global)
        buf @ Bufferize { compute: _ } => |buf, ctx| {
            bufferize_to_store(buf, ctx, false)
        },
    }
}
1742
/// Create pattern matcher for adding buffers with local buffer support.
///
/// Based on Tinygrad's pm_add_buffers_local (rangeify.py:358-367) with `allow_locals=True`.
/// Uses a shared KernelContext for unique buffer IDs.
///
/// The rules are the same as `pm_add_buffers_patterns` except for the final one, which calls
/// `bufferize_to_store(buf, ctx, true)` so local-addrspace buffers get DEFINE_LOCAL storage.
/// Rules, in declaration order:
/// 1. BUFFERIZE with more than one range → `flatten_bufferize`.
/// 2. INDEX of a movement op → `transform_movement_through_index` (pm_mops rule 1).
/// 3. AFTER over a movement op → `push_movement_through_after` (pm_mops rule 2).
/// 4. END over a movement op → END over the movement op's first source (pm_mops rule 3).
/// 5. BUFFERIZE of BITCAST/CONTIGUOUS → `late_buffer_view`, which returns `None` unless the
///    BUFFERIZE targets a DISK device.
/// 6. Any BUFFERIZE → `bufferize_to_store(buf, ctx, true)`.
///
/// NOTE(review): this assumes `patterns!` tries rules in the order written and falls through
/// when a rule returns `None`, so the unguarded rule 6 acts as the fallback — confirm.
pub fn pm_add_buffers_local_patterns() -> crate::TypedPatternMatcher<super::kernel::KernelContext> {
    crate::patterns! {
        @context super::kernel::KernelContext;
        // Flatten multi-range BUFFERIZE to 1D (Tinygrad: flatten_bufferize, rangeify.py:381-389)
        // The pattern binds only `compute`, so the guard re-inspects `buf.op()` for the range count.
        buf @ Bufferize { compute: _ } if matches!(buf.op(), Op::Bufferize { ranges, .. } if ranges.len() > 1)
            => |buf, _ctx| { flatten_bufferize(buf) },
        // pm_mops rule 1: push movement ops through INDEX (Tinygrad rangeify.py:25-26)
        Index { buffer: mop, indices, gate } if mop.op().is_movement()
            => |mop, indices, gate, _ctx| {
                super::patterns::transform_movement_through_index(mop, indices, gate)
            },
        // pm_mops rule 2: push movement ops through AFTER (Tinygrad rangeify.py:28-29)
        // AFTER(MOVEMENT(x, ...), deps) → MOVEMENT(AFTER(x, deps), ...)
        After { passthrough: mop, deps } if mop.op().is_movement()
            => |mop, deps, _ctx| {
                push_movement_through_after(mop, deps)
            },
        // pm_mops rule 3: strip movement ops from END (Tinygrad rangeify.py:30)
        // END(MOVEMENT(x, ...), ranges) → END(x, ranges), keeping the END's ranges unchanged.
        End { computation: mop, ranges } if mop.op().is_movement()
            => |mop, ranges, _ctx| {
                let src = &mop.op().sources()[0];
                Some(src.end(ranges.clone()))
            },
        // to_bufferview: DISK BUFFERIZE(BITCAST|CONTIGUOUS) → BUFFER_VIEW (Tinygrad rangeify.py:302-304)
        // late_buffer_view returns None for non-DISK devices.
        buf @ Bufferize { compute }
            if matches!(compute.op(), Op::BitCast { .. } | Op::Contiguous { .. })
            => |buf, compute, _ctx| late_buffer_view(compute, buf),
        // BUFFERIZE → STORE conversion (allow_locals=true: create DEFINE_LOCAL for local addrspace)
        buf @ Bufferize { compute: _ } => |buf, ctx| {
            bufferize_to_store(buf, ctx, true)
        },
    }
}