Skip to main content

morok_schedule/
devectorize.rs

1//! Devectorize pass — single-pass combined matcher matching Tinygrad's `pm_devectorize`.
2//!
3//! Composition: `sym + devectorize + load_store_folding + correct_load_store + load_store_indexing`
4//!
5//! All patterns run in one `graph_rewrite` call with fixed-point convergence.
6//! PtrCat is an intermediate created by `fold_expanded_index` and eliminated by
7//! `distribute_ptrcat_load/store` within the same pass. It must never reach codegen.
8//!
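//! # pm_render (called AFTER devectorize)
//!
//! - Vector CONST / VCONST → VECTORIZE of scalar CONSTs
//! - CAT → VECTORIZE (CAT can't be rendered directly)
//! - Multi-index GEP → VECTORIZE of single-index GEPs
//! - Single-element VECTORIZE/PTRCAT → unwrap
//! - Gated LOADs get an explicit alt value (0, or the other branch of a matching WHERE)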
14
15use std::collections::HashMap;
16use std::collections::HashSet;
17use std::sync::Arc;
18use std::sync::LazyLock;
19
20use itertools::Itertools;
21use morok_dtype::{AddrSpace, DType, ScalarDType};
22use morok_ir::{BinaryOp, ConstValue, Op, ReduceOp, TernaryOp, UOp, UOpKey, UnaryOp, WmmaMetadata};
23
24use crate::TypedPatternMatcher;
25use smallvec::SmallVec;
26
27/// Context for REDUCE transformation (Tinygrad devectorizer.py:280-281).
28///
29/// Tracks END nodes created per reduce-range set so that multiple ENDs sharing
30/// the same ranges can be merged into a single END with a GROUP body.
31#[derive(Debug, Default)]
32pub struct ReduceContext {
33    range_to_ends: HashMap<SmallVec<[u64; 4]>, Vec<Arc<UOp>>>,
34}
35
36impl ReduceContext {
37    /// Register an END node under its reduce-range key.
38    pub fn register_end(&mut self, end: &Arc<UOp>) {
39        if let Op::End { ranges, .. } = end.op() {
40            let mut key: SmallVec<[u64; 4]> = ranges.iter().map(|r| r.id).collect();
41            key.sort_unstable();
42            self.range_to_ends.entry(key).or_default().push(end.clone());
43        }
44    }
45
46    /// Merge END nodes that share the same reduce ranges.
47    ///
48    /// For each group of >1 ENDs with identical range sets, creates a single
49    /// `GROUP(computation1, computation2, ...).end(ranges)` and substitutes it
50    /// throughout the SINK subgraph. Clears tracking state after merge.
51    ///
52    /// Matches Tinygrad's `merge_reduce_ends` (devectorizer.py:333-336).
53    pub fn merge_reduce_ends(&mut self, sources: &SmallVec<[Arc<UOp>; 4]>) -> Option<Arc<UOp>> {
54        #[allow(clippy::mutable_key_type)]
55        let subs = build_end_merge_subs(&self.range_to_ends);
56        self.range_to_ends.clear();
57        if subs.is_empty() {
58            return None;
59        }
60        Some(UOp::sink(sources.to_vec()).substitute(&subs))
61    }
62}
63
64/// Core merge logic: given a map of range-key → END nodes, build substitutions.
65#[allow(clippy::mutable_key_type)]
66fn build_end_merge_subs(range_to_ends: &HashMap<SmallVec<[u64; 4]>, Vec<Arc<UOp>>>) -> HashMap<UOpKey, Arc<UOp>> {
67    let mut subs = HashMap::new();
68    for ends in range_to_ends.values() {
69        if ends.len() <= 1 {
70            continue;
71        }
72        let computations: Vec<Arc<UOp>> = ends
73            .iter()
74            .map(|e| match e.op() {
75                Op::End { computation, .. } => computation.clone(),
76                _ => unreachable!(),
77            })
78            .collect();
79        let ranges = match ends[0].op() {
80            Op::End { ranges, .. } => ranges.clone(),
81            _ => unreachable!(),
82        };
83        let merged = UOp::group(computations).end(ranges);
84        for end in ends {
85            subs.insert(UOpKey(end.clone()), merged.clone());
86        }
87    }
88    subs
89}
90
91/// Merge sibling END nodes that share the same reduce ranges (standalone pass).
92///
93/// Walks the SINK subgraph, discovers all END nodes, groups by range key,
94/// and merges groups of >1 into `GROUP(computations...).end(ranges)`.
95///
96/// This is the same merge as `ReduceContext::merge_reduce_ends` but doesn't
97/// require tracking during pm_reduce — it discovers ENDs from the graph directly.
98/// Needed because later passes (e.g. pm_decomp+pm_render) can create new sibling
99/// ENDs that weren't present when pm_reduce ran.
100pub fn merge_sibling_ends(sink: &Arc<UOp>) -> Arc<UOp> {
101    let Op::Sink { sources } = sink.op() else { return sink.clone() };
102
103    let mut range_to_ends: HashMap<SmallVec<[u64; 4]>, Vec<Arc<UOp>>> = HashMap::new();
104    for node in sink.toposort() {
105        if let Op::End { ranges, .. } = node.op() {
106            let mut key: SmallVec<[u64; 4]> = ranges.iter().map(|r| r.id).collect();
107            key.sort_unstable();
108            range_to_ends.entry(key).or_default().push(node.clone());
109        }
110    }
111
112    #[allow(clippy::mutable_key_type)]
113    let subs = build_end_merge_subs(&range_to_ends);
114    if subs.is_empty() {
115        return sink.clone();
116    }
117    UOp::sink(sources.to_vec()).substitute(&subs)
118}
119
120use crate::rewrite::graph_rewrite;
121use crate::symbolic::patterns::sym;
122
123// ============================================================================
124// Main Entry Point
125// ============================================================================
126
127/// Run devectorize pass. Call AFTER `pre_expand`, BEFORE codegen.
128///
129/// Single-pass combined matcher (Tinygrad `pm_devectorize`):
130/// sym + devectorize + load_store_folding + correct_load_store + load_store_indexing
131///
132/// Note: `bool_storage_patterns()` called separately (backend-specific).
133/// Note: `pm_render()` should be applied AFTER this pass.
134pub fn devectorize(ast: &Arc<UOp>) -> Arc<UOp> {
135    static COMBINED: LazyLock<TypedPatternMatcher> = LazyLock::new(|| {
136        sym()
137            + devectorize_patterns()
138            + load_store_folding_patterns()
139            + correct_load_store_patterns()
140            + load_store_indexing_patterns()
141    });
142    graph_rewrite(&*COMBINED, ast.clone(), &mut ())
143}
144
145/// Bool LOAD/STORE via uint8. LLVM i1 can have garbage in upper bits.
146/// Also rewrites BitCast involving Bool to Cast (bitcast requires same bit-width).
147pub fn bool_storage_patterns() -> &'static TypedPatternMatcher {
148    crate::cached_patterns! {
149        // STORE bool: cast to uint8 before storing
150        Store { index, value, ranges } if value.dtype().base().is_bool() => {
151            let uint8_dtype = value.dtype().with_base(ScalarDType::UInt8);
152            Some(index.store_with_ranges(value.cast(uint8_dtype), ranges.clone()))
153        },
154
155        // LOAD bool: load as uint8, then cast to bool
156        load @ Load { buffer, index } if load.dtype().base().is_bool() => {
157            let uint8_dtype = load.dtype().with_base(ScalarDType::UInt8);
158            let uint8_load = UOp::load().buffer(buffer.clone()).index(index.clone()).dtype(uint8_dtype).call();
159            Some(uint8_load.cast(load.dtype()))
160        },
161
162        // BitCast with Bool: i1 has different bit-width than i8+, use Cast instead
163        BitCast { src, dtype } if src.dtype().base().is_bool() || dtype.base().is_bool() => {
164            Some(src.cast(dtype.clone()))
165        },
166    }
167}
168
169// ============================================================================
170// FP8 Float Decomposition (Tinygrad: pm_float_decomp, decompositions.py:504-522)
171// ============================================================================
172
173/// Context for FP8 float decomposition.
174/// `from` is the FP8 dtype being decomposed, `to` is the target float dtype.
175#[derive(Debug, Clone)]
176pub struct Fp8DecompCtx {
177    pub from: ScalarDType,
178    pub to: ScalarDType,
179}
180
181/// Round-to-nearest-even for integer bitwise rounding.
182/// Port of Tinygrad's `rne(v, s)` (decompositions.py:383).
183fn rne(v: &Arc<UOp>, s: u32) -> Arc<UOp> {
184    let one = v.const_like(1);
185    let shifted = v.shr(&v.const_like(s));
186    let half_bit = v.shr(&v.const_like(s - 1)).and_(&one);
187    let remainder_mask = v.const_like((1i64 << (s - 1)) - 1);
188    let has_remainder = v.and_(&remainder_mask).ne(&v.const_like(0)).cast(v.dtype());
189    let lsb = shifted.and_(&one);
190    let round_up = half_bit.and_(&has_remainder.or_(&lsb));
191    shifted.try_add(&round_up).expect("rne: add failed")
192}
193
194/// Bitwise float-to-float format conversion.
195/// Port of Tinygrad's `f2f(v, fr, to)` (decompositions.py:385-404).
196///
197/// `v` is a UInt value holding the raw bits of the source float.
198/// Returns a UOp holding raw bits of the target float, which must be bitcast to get the float value.
199fn f2f(v: &Arc<UOp>, fr: ScalarDType, to: ScalarDType) -> Arc<UOp> {
200    let (fe, fm) = fr.finfo();
201    let (te, tm) = to.finfo();
202    let fs = fr.bitsize();
203    let ts = to.bitsize();
204    let fb = fr.exponent_bias() as i64;
205    let tb = to.exponent_bias() as i64;
206    let fr_uint = DType::Scalar(fr.float_to_uint());
207    let to_uint = DType::Scalar(to.float_to_uint());
208
209    if fe <= te && fm < tm {
210        // Upcast path: e.g. FP8 → Float16
211        let sign_mask = v.const_like(1i64 << (fs - 1));
212        let sign = v.and_(&sign_mask).cast(to_uint.clone()).shl(&v.const_like(ts - fs).cast(to_uint.clone()));
213        let nosign_mask = v.const_like((1i64 << (fs - 1)) - 1);
214        let nosign = v.and_(&nosign_mask).cast(to_uint.clone());
215        let exp = nosign.shr(&nosign.const_like(fm));
216        let norm = nosign
217            .shl(&nosign.const_like(tm - fm))
218            .try_add(&nosign.const_like((tb - fb) << tm))
219            .expect("f2f: add failed");
220        let nan_val = nosign.shl(&nosign.const_like(tm - fm)).or_(&nosign.const_like(((1i64 << te) - 1) << tm));
221
222        // FP8E4M3 has a single NaN value (all exponent+mantissa bits set)
223        let is_nan = if fr == ScalarDType::FP8E4M3 {
224            nosign.eq(&nosign.const_like((1i64 << (fm + fe)) - 1))
225        } else {
226            exp.eq(&exp.const_like((1i64 << fe) - 1))
227        };
228
229        let zero = nosign.const_like(0);
230        let exp_is_zero = exp.eq(&zero);
231        let inner = UOp::try_where(is_nan, nan_val, norm).expect("f2f: where failed");
232        let result = UOp::try_where(exp_is_zero, zero, inner).expect("f2f: where failed");
233        sign.or_(&result).bitcast(DType::Scalar(to))
234    } else if fe >= te && fm > tm {
235        // Downcast path: e.g. Float16 → FP8
236        let clamped = f2f_clamp(&v.bitcast(DType::Scalar(fr)), to);
237        let v = clamped.bitcast(fr_uint);
238        let sign = v.shr(&v.const_like(fs - ts)).and_(&v.const_like(1i64 << (ts - 1)));
239        let nosign_mask = v.const_like((1i64 << (fs - 1)) - 1);
240        let nosign = v.and_(&nosign_mask);
241        let norm = rne(&nosign, fm - tm)
242            .try_sub(&nosign.const_like((fb - tb) << tm))
243            .expect("f2f: sub failed")
244            .cast(to_uint.clone());
245
246        let exp_field = nosign.shr(&nosign.const_like(fm)).and_(&nosign.const_like((1i64 << fe) - 1));
247        let underflow = exp_field.lt(&exp_field.const_like(1 + fb - tb));
248
249        let nan_mantissa = if to == ScalarDType::FP8E4M3 {
250            sign.const_like((1i64 << tm) - 1).cast(to_uint.clone())
251        } else {
252            nosign.shr(&nosign.const_like(fm - tm)).and_(&nosign.const_like((1i64 << tm) - 1)).cast(to_uint.clone())
253        };
254        let nan_exp = sign.const_like(((1i64 << te) - 1) << tm).cast(to_uint.clone());
255        let nan = sign.cast(to_uint.clone()).or_(&nan_mantissa).or_(&nan_exp);
256
257        let is_nan = exp_field.eq(&exp_field.const_like((1i64 << fe) - 1));
258        let zero = sign.const_like(0).cast(to_uint.clone());
259        let normal = sign.cast(to_uint.clone()).or_(&UOp::try_where(underflow, zero, norm).expect("f2f: where failed"));
260        UOp::try_where(is_nan, nan, normal).expect("f2f: where failed")
261    } else {
262        panic!("f2f: unsupported conversion {fr:?} -> {to:?}")
263    }
264}
265
266/// Clamp a float value to the representable range of a target FP8 dtype.
267/// Port of Tinygrad's `f2f_clamp` (decompositions.py:406-412).
268fn f2f_clamp(val: &Arc<UOp>, dt: ScalarDType) -> Arc<UOp> {
269    let (e, m) = dt.finfo();
270    let (max_exp, max_man): (i64, i64) =
271        if dt == ScalarDType::FP8E4M3 { ((1 << e) - 1, (1 << m) - 2) } else { ((1 << e) - 2, (1 << m) - 1) };
272    let mx_f64 =
273        f64::powi(2.0, (max_exp - dt.exponent_bias() as i64) as i32) * (1.0 + max_man as f64 / (1i64 << m) as f64);
274    let mx = val.const_like(mx_f64);
275    let neg_mx = val.const_like(-mx_f64);
276
277    // For FP8 types, clamp to ±max; for others, clamp to ±inf
278    let sat = if dt.is_fp8() { mx.clone() } else { val.const_like(f64::INFINITY) };
279    let neg_sat = if dt.is_fp8() { neg_mx.clone() } else { val.const_like(f64::NEG_INFINITY) };
280
281    // nan → nan, < -mx → -sat, > mx → sat, otherwise → val
282    let is_nan = val.ne(val);
283    let below = val.lt(&neg_mx);
284    let above = mx.lt(val);
285    let clamped_above = UOp::try_where(above, sat, val.clone()).expect("f2f_clamp: where failed");
286    let clamped = UOp::try_where(below, neg_sat, clamped_above).expect("f2f_clamp: where failed");
287    UOp::try_where(is_nan, val.clone(), clamped).expect("f2f_clamp: where failed")
288}
289
290/// FP8 STORE decomposition patterns (bpm — sees ORIGINAL children).
291///
292/// The STORE pattern must run in the bpm slot so it sees the ORIGINAL index dtype
293/// (still FP8) before Pattern 1 changes it to UInt8. This is the Morok equivalent
294/// of Tinygrad's `tag` mechanism in `pm_float_decomp`.
295pub fn pm_float_decomp_store() -> crate::TypedPatternMatcher<Fp8DecompCtx> {
296    crate::patterns! {
297        @context Fp8DecompCtx;
298
299        // STORE to FP8 buffer → f2f convert value→UInt8, store
300        // In bpm, index still has FP8 ptr (ORIGINAL children, before Pattern 1 runs).
301        Store { index, value, ranges }
302            if index.dtype().base() == ctx.from
303        => {
304            let target_float = DType::Scalar(ctx.to);
305            let target_uint = DType::Scalar(ctx.to.float_to_uint());
306            // Cast value to target float (handles FP8, Float32, etc. → Float16)
307            let float_val = value.cast(target_float);
308            // Bitwise float→FP8 conversion (includes clamping internally)
309            let result = f2f(&float_val.bitcast(target_uint), ctx.to, ctx.from);
310            // Change index ptr to UInt8
311            let uint8_ptr = index.dtype().with_ptr_base(DType::Scalar(ctx.from.float_to_uint()))?;
312            let new_index = index.with_dtype(uint8_ptr);
313            Some(new_index.store_with_ranges(result, ranges.clone()))
314        },
315    }
316}
317
318/// FP8 float decomposition patterns (pm — sees OPTIMIZED children).
319///
320/// Port of Tinygrad's `pm_float_decomp` (decompositions.py:504-522).
321/// Run via `graph_rewrite_with_bpm` together with `pm_float_decomp_store()`.
322pub fn pm_float_decomp() -> crate::TypedPatternMatcher<Fp8DecompCtx> {
323    crate::patterns! {
324        @context Fp8DecompCtx;
325
326        // Pattern 1: INDEX/DEFINE with FP8 ptr base → change ptr to UInt8
327        x if matches!(x.op(), Op::Param { device: None, .. } | Op::DefineLocal(_) | Op::Index { .. })
328            && x.dtype().base() == ctx.from
329        => {
330            let uint8_ptr = x.dtype().with_ptr_base(DType::Scalar(ctx.from.float_to_uint()))?;
331            Some(x.with_dtype(uint8_ptr))
332        },
333
334        // Pattern 2: LOAD with FP8 dtype → load as UInt8, f2f upcast to target float
335        load @ Load { buffer, index } if load.dtype().base() == ctx.from => {
336            let uint_dtype = DType::Scalar(ctx.from.float_to_uint());
337            let uint_load = UOp::load().buffer(buffer.clone()).index(index.clone())
338                .dtype(uint_dtype).call();
339            Some(f2f(&uint_load, ctx.from, ctx.to))
340        },
341
342        // Pattern 5: CAST to FP8 → full round-trip (Float16→FP8 bytes→Float16).
343        // Must do the complete conversion (not just clamp) because the kernel may fuse
344        // Cast(Float16→FP8) and Cast(FP8→Float32) without materializing the FP8 buffer.
345        x @ Cast { src: val, .. } if x.dtype().base() == ctx.from => {
346            let target = DType::Scalar(ctx.to);
347            let target_uint = DType::Scalar(ctx.to.float_to_uint());
348            let float_val = val.cast(target);
349            // Downcast: Float16 bits → FP8 bytes (includes clamping)
350            let fp8_bytes = f2f(&float_val.bitcast(target_uint.clone()), ctx.to, ctx.from);
351            // Upcast: FP8 bytes → Float16 (proper FP8-quantized value)
352            Some(f2f(&fp8_bytes, ctx.from, ctx.to))
353        },
354
355        // Pattern 6: Any op with FP8 output dtype → promote to target float, cast FP8 sources
356        x if !matches!(x.op(), Op::BitCast { .. })
357            && x.dtype().is_float()
358            && x.dtype().base() == ctx.from
359        => {
360            let target_dtype = DType::Scalar(ctx.to);
361            let new_dtype = if x.dtype().vcount() > 1 {
362                target_dtype.vec(x.dtype().vcount())
363            } else {
364                target_dtype.clone()
365            };
366            let new_sources: Vec<Arc<UOp>> = x.op().sources().iter().map(|s| {
367                if s.dtype().base() == ctx.from {
368                    s.cast(target_dtype.clone())
369                } else {
370                    s.clone()
371                }
372            }).collect();
373            Some(x.with_sources(new_sources).with_dtype(new_dtype))
374        },
375    }
376}
377
378/// Post-devectorize rendering patterns (devectorizer.py:258-275).
379/// Called during codegen, NOT part of pm_devectorize.
380pub fn pm_render() -> &'static TypedPatternMatcher {
381    crate::cached_patterns! {
382        // Vector CONST → VECTORIZE of scalar CONST (devectorizer.py:260-261)
383        c @ Const(_) if c.dtype().vcount() > 1 => |c| {
384            let vcount = c.dtype().vcount();
385            let Op::Const(cv) = c.op() else { return None };
386            let scalar_const = UOp::const_(c.dtype().scalar_dtype(), cv.0);
387            let elements: SmallVec<[Arc<UOp>; 4]> = (0..vcount).map(|_| scalar_const.clone()).collect();
388            Some(UOp::vectorize(elements))
389        },
390
391        // VCONST → VECTORIZE of scalar CONSTs (devectorizer.py:262)
392        vc @ VConst { values } => |vc, values| {
393            let scalar_dtype = vc.dtype().scalar_dtype();
394            let elements: SmallVec<[Arc<UOp>; 4]> = values.iter()
395                .map(|v| UOp::const_(scalar_dtype.clone(), *v))
396                .collect();
397            Some(UOp::vectorize(elements))
398        },
399
400        // CAT → VECTORIZE (CAT can't be rendered)
401        Cat { sources } if sources.len() == 1 => Some(sources[0].clone()),
402        Cat { sources } => {
403            let elements: SmallVec<[Arc<UOp>; 4]> = sources.iter()
404                .flat_map(|src| {
405                    let n = src.dtype().vcount();
406                    (0..n).map(move |i| if n == 1 { src.clone() } else { src.gep(vec![i]) })
407                })
408                .collect();
409            Some(UOp::vectorize(elements))
410        },
411
412        // GEP on scalar → identity
413        Gep { vector, indices } if vector.dtype().vcount() == 1 && indices.len() == 1 && indices[0] == 0
414            ~> |vector| Arc::clone(vector),
415
416        // GEP identity: [0,1,...,n-1] → unwrap (must be before multi-index GEP)
417        Gep { vector, indices } if is_identity_gep(vector, indices) => Some(vector.clone()),
418
419        // GEP(VECTORIZE) → extract (must be before multi-index GEP)
420        Gep { vector: Vectorize { elements }, indices } => {
421            let extracted: SmallVec<[Arc<UOp>; 4]> = indices.iter()
422                .filter_map(|&i| elements.get(i).cloned())
423                .collect();
424            if extracted.len() != indices.len() { return None; }
425            Some(if extracted.len() == 1 { extracted[0].clone() } else { UOp::vectorize(extracted) })
426        },
427
428        // Multi-index GEP → VECTORIZE (fallback, must be last GEP pattern)
429        Gep { vector, indices } if indices.len() > 1 => |vector, indices| {
430            let geps: SmallVec<[Arc<UOp>; 4]> = indices.iter()
431                .map(|&i| vector.gep(vec![i]))
432                .collect();
433            Some(UOp::vectorize(geps))
434        },
435
436        // Single-element unwrap
437        Vectorize { elements } if elements.len() == 1 => Some(elements[0].clone()),
438        PtrCat { sources } if sources.len() == 1 => Some(sources[0].clone()),
439
440        // =========================================================================
441        // Gated Load Alt Patterns (devectorizer.py:266-274)
442        // =========================================================================
443
444        // Give any gated LOADs without alt a const 0 alt value (devectorizer.py:267-269)
445        // LOAD(INDEX(buf, idx, gate)) where alt is None → LOAD with alt=0
446        load @ Load { index, alt: None, .. } if has_gate(index) => |load, index| {
447            let alt_value = load.const_like(ConstValue::Int(0));
448            Some(UOp::load().buffer(load.load_buffer()?).index(index.clone()).alt(alt_value).dtype(load.dtype()).call())
449        },
450
451        // WHERE(c, LOAD(INDEX(buf, idx, c)), alt) → LOAD with alt value (devectorizer.py:289-291)
452        // The load's gate matches the WHERE condition.
453        // Matches Tinygrad's allow_any_len=True (no alt: None guard).
454        // NOTE: if alt is CAST and alt.src.dtype == load.dtype, use alt.src to avoid
455        // roundtrip cast (e.g. uint->float->uint).
456        Where(cond, load @ Load { index, .. }, alt)
457            if index_has_gate_matching(index, cond)
458            => |cond, load, index, alt| {
459                let casted_alt = cast_alt_avoiding_roundtrip(alt, &load.dtype());
460                let new_load = UOp::load()
461                    .buffer(load.load_buffer()?)
462                    .index(index.clone())
463                    .alt(casted_alt)
464                    .dtype(load.dtype())
465                    .call();
466                Some(new_load.cast(alt.dtype()))
467            },
468
469        // WHERE(c, alt, LOAD(INDEX(buf, idx, !c))) → LOAD with alt value (devectorizer.py:292-294)
470        // Same pattern but with inverted condition in WHERE.
471        // is_negation_of handles NOT(cond) and pm_comparison_negations simplified forms.
472        Where(cond, alt, load @ Load { index, .. })
473            if index_has_inverted_gate_matching(index, cond)
474            => |cond, alt, load, index| {
475                let casted_alt = cast_alt_avoiding_roundtrip(alt, &load.dtype());
476                let new_load = UOp::load()
477                    .buffer(load.load_buffer()?)
478                    .index(index.clone())
479                    .alt(casted_alt)
480                    .dtype(load.dtype())
481                    .call();
482                Some(new_load.cast(alt.dtype()))
483            },
484    }
485}
486
487/// Cast alt value to load dtype, avoiding roundtrip casts (devectorizer.py:290).
488/// If alt is CAST(inner) and inner.dtype == load_dtype, use inner directly
489/// to avoid e.g. uint→float→uint.
490fn cast_alt_avoiding_roundtrip(alt: &Arc<UOp>, load_dtype: &DType) -> Arc<UOp> {
491    if let Op::Cast { src: inner, .. } = alt.op()
492        && inner.dtype() == *load_dtype
493    {
494        return inner.clone();
495    }
496    alt.cast(load_dtype.clone())
497}
498
499/// Check if GEP is identity: GEP(x, [0,1,...,n-1]) where n == x.vcount
500fn is_identity_gep(vector: &Arc<UOp>, indices: &[usize]) -> bool {
501    let vcount = vector.dtype().vcount();
502    indices.len() == vcount && indices.iter().enumerate().all(|(i, &j)| i == j)
503}
504
505/// Check if index (or casted index) has a gate.
506fn has_gate(index: &Arc<UOp>) -> bool {
507    match index.op() {
508        Op::Index { gate: Some(_), .. } => true,
509        Op::Cast { src, .. } => has_gate(src),
510        _ => false,
511    }
512}
513
514/// Check if index has a gate that matches the given condition (pointer equality).
515fn index_has_gate_matching(index: &Arc<UOp>, cond: &Arc<UOp>) -> bool {
516    match index.op() {
517        Op::Index { gate: Some(g), .. } => Arc::ptr_eq(g, cond),
518        Op::Cast { src, .. } => index_has_gate_matching(src, cond),
519        _ => false,
520    }
521}
522
523/// Check if index has an inverted gate that matches the given condition.
524///
525/// Matches INDEX gate that is semantically NOT(cond). Handles three forms:
526/// 1. `NOT(cond)` — structural NOT, pointer-equal inner
527/// 2. `Lt(c-1, x)` when cond = `Lt(x, c)` — result of pm_comparison_negations on NOT(Lt(x,c))
528/// 3. `Lt(x, c+1)` when cond = `Lt(c, x)` — result of pm_comparison_negations on NOT(Lt(c,x))
529///
530/// Form 2/3 arise because dce_dsl_patterns swaps WHERE(NOT(Lt), t, f) → WHERE(Lt, f, t),
531/// then pm_comparison_negations converts the NOT(Lt) on the INDEX gate to a reversed Lt.
532fn index_has_inverted_gate_matching(index: &Arc<UOp>, cond: &Arc<UOp>) -> bool {
533    match index.op() {
534        Op::Index { gate: Some(g), .. } => is_negation_of(g, cond),
535        Op::Cast { src, .. } => index_has_inverted_gate_matching(src, cond),
536        _ => false,
537    }
538}
539
540/// Check if `gate` is semantically NOT(cond).
541fn is_negation_of(gate: &Arc<UOp>, cond: &Arc<UOp>) -> bool {
542    // Form 1: NOT(cond) — structural
543    if let Op::Unary(UnaryOp::Not, inner) = gate.op()
544        && Arc::ptr_eq(inner, cond)
545    {
546        return true;
547    }
548
549    // Form 2: gate = Lt(c-1, x), cond = Lt(x, c) — from pm_comparison_negations on NOT(Lt(x,c))
550    // Form 3: gate = Lt(x, c+1), cond = Lt(c, x) — from pm_comparison_negations on NOT(Lt(c,x))
551    if let Op::Binary(BinaryOp::Lt, gate_lhs, gate_rhs) = gate.op()
552        && let Op::Binary(BinaryOp::Lt, cond_lhs, cond_rhs) = cond.op()
553    {
554        // Form 2: gate_rhs == cond_lhs (same x), gate_lhs == c-1 where cond_rhs == c
555        if Arc::ptr_eq(gate_rhs, cond_lhs)
556            && let (Op::Const(gate_cv), Op::Const(cond_cv)) = (gate_lhs.op(), cond_rhs.op())
557            && is_const_minus_one(&cond_cv.0, &gate_cv.0)
558        {
559            return true;
560        }
561        // Form 3: gate_lhs == cond_rhs (same x), gate_rhs == c+1 where cond_lhs == c
562        if Arc::ptr_eq(gate_lhs, cond_rhs)
563            && let (Op::Const(cond_cv), Op::Const(gate_cv)) = (cond_lhs.op(), gate_rhs.op())
564            && is_const_minus_one(&cond_cv.0, &gate_cv.0)
565        {
566            return true;
567        }
568    }
569
570    false
571}
572
573/// Check if `a - 1 == b` (i.e., b = a - 1).
574fn is_const_minus_one(a: &ConstValue, b: &ConstValue) -> bool {
575    match (a, b) {
576        (ConstValue::Int(av), ConstValue::Int(bv)) => av.checked_sub(1) == Some(*bv),
577        (ConstValue::UInt(av), ConstValue::UInt(bv)) => av.checked_sub(1) == Some(*bv),
578        _ => false,
579    }
580}
581
582// ============================================================================
583// ALU Devectorization
584// ============================================================================
585
586/// Generic ALU devectorization: Vector ALU → VECTORIZE of scalar ALU.
587///
588/// Mirrors Tinygrad's `no_vectorized_alu` (devectorizer.py:219-223):
589/// ```python
590/// alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg)
591///              for i in range(alu.dtype.vcount))
592/// return UOp(Ops.VECTORIZE, alu.dtype, alus)
593/// ```
594fn devectorize_alu(alu: &Arc<UOp>) -> Option<Arc<UOp>> {
595    let vcount = alu.dtype().vcount();
596    if vcount <= 1 {
597        return None;
598    }
599
600    // Skip WHERE(cond, t, Invalid) - used for image indexing (devectorizer.py:232)
601    // Handles both scalar Invalid and vectorized VECTORIZE(Invalid,...) from expansion.
602    if let Op::Ternary(TernaryOp::Where, _, _, f) = alu.op()
603        && UOp::is_invalid_marker(f)
604    {
605        return None;
606    }
607
608    let scalar_dtype = alu.dtype().scalar_dtype();
609    let sources = alu.op().sources();
610
611    let elements: SmallVec<[Arc<UOp>; 4]> = (0..vcount)
612        .map(|i| {
613            // Apply GEP to each source, broadcasting scalars
614            let new_sources: Vec<Arc<UOp>> =
615                sources.iter().map(|s| if s.dtype().vcount() > 1 { s.gep(vec![i]) } else { s.clone() }).collect();
616
617            // CAST and BITCAST need special handling: Op::Cast/BitCast has its own dtype field
618            // that must be updated to scalar, not just the UOp's result dtype.
619            // The generic replace chain doesn't update Op::Cast::dtype.
620            match alu.op() {
621                Op::Cast { .. } => new_sources[0].cast(scalar_dtype.clone()),
622                Op::BitCast { .. } => new_sources[0].bitcast(scalar_dtype.clone()),
623                _ => alu.replace().dtype(scalar_dtype.clone()).src(new_sources).call(),
624            }
625        })
626        .collect();
627
628    Some(UOp::vectorize(elements))
629}
630
631/// Vector ALU → VECTORIZE of scalar ALU (devectorizer.py:219-223).
632/// LLVM SLP can re-vectorize when beneficial.
633#[allow(unused_variables)]
634pub fn no_vectorized_alu() -> &'static TypedPatternMatcher {
635    crate::cached_patterns! {
636        // All binary ops
637        for op in binary [*] {
638            alu @ op(_, _) if alu.dtype().vcount() > 1 => devectorize_alu(alu),
639        },
640        // All unary ops
641        for op in unary [*] {
642            alu @ op(_) if alu.dtype().vcount() > 1 => devectorize_alu(alu),
643        },
644        // All ternary ops (Where, MulAcc)
645        for op in ternary [*] {
646            alu @ op(_, _, _) if alu.dtype().vcount() > 1 => devectorize_alu(alu),
647        },
648        // Cast and BitCast
649        alu @ Cast { src: _, .. } if alu.dtype().vcount() > 1 => devectorize_alu(alu),
650        alu @ BitCast { src: _, .. } if alu.dtype().vcount() > 1 => devectorize_alu(alu),
651    }
652}
653
654// ============================================================================
655// Devectorize Patterns (devectorizer.py:250-256)
656// ============================================================================
657
658/// Combined devectorize patterns: cast_after, ALU, WMMA, buffer/index devectorization.
659pub fn devectorize_patterns() -> &'static TypedPatternMatcher {
660    use std::sync::LazyLock;
661    static CACHED: LazyLock<TypedPatternMatcher> = LazyLock::new(|| {
662        cast_after_pattern() + no_vectorized_alu() + no_vectorized_wmma() + devectorize_buf_and_index_patterns()
663    });
664    &CACHED
665}
666
667/// WMMA devectorization (devectorizer.py:208-217).
668fn no_vectorized_wmma() -> &'static TypedPatternMatcher {
669    crate::cached_patterns! {
670        wmma @ Wmma { a, b, c, metadata } if wmma.dtype().vcount() > wmma_expected_size(metadata)
671            => devectorize_wmma(wmma, a, b, c, metadata),
672    }
673}
674
675fn wmma_expected_size(metadata: &WmmaMetadata) -> usize {
676    metadata.upcast_axes.c.iter().map(|(_, size)| size).product::<usize>().max(1)
677}
678
679fn devectorize_wmma(
680    wmma: &Arc<UOp>,
681    a: &Arc<UOp>,
682    b: &Arc<UOp>,
683    c: &Arc<UOp>,
684    metadata: &WmmaMetadata,
685) -> Option<Arc<UOp>> {
686    let out_sz = wmma_expected_size(metadata);
687    if wmma.dtype().vcount() == out_sz {
688        return None;
689    }
690
691    // Split each source by its OWN axis sizes (A, B, C may differ).
692    // For CUDA 8-16-16 with elements_per_thread=(8,4,4):
693    //   A split by 8, B split by 4, C split by 4.
694    let sources: [&Arc<UOp>; 3] = [a, b, c];
695    let tsrcs: Vec<Vec<Arc<UOp>>> = sources
696        .iter()
697        .enumerate()
698        .map(|(i, src)| {
699            let ssz = metadata.upcast_axes.source_size(i);
700            let n = src.dtype().vcount();
701            (0..n).step_by(ssz).map(|g| src.gep((g..g + ssz.min(n - g)).collect())).collect()
702        })
703        .collect();
704
705    // Verify all sources have same number of groups
706    let num_groups = tsrcs[0].len();
707    if tsrcs.iter().any(|t| t.len() != num_groups) {
708        tracing::warn!("WMMA devectorization: mismatched source group counts");
709        return None;
710    }
711
712    // Create new WMMA for each group, flatten with GEP
713    let wmma_ex: SmallVec<[Arc<UOp>; 4]> = (0..num_groups)
714        .flat_map(|g| {
715            let w = UOp::wmma(tsrcs[0][g].clone(), tsrcs[1][g].clone(), tsrcs[2][g].clone(), metadata.clone());
716            (0..out_sz).map(move |i| w.gep(vec![i]))
717        })
718        .collect();
719
720    Some(UOp::vectorize(wmma_ex))
721}
722
/// AFTER(CAST(x), deps) → CAST(AFTER(x, deps)) - allows cast to be optimized independently.
///
/// The rewrite attaches the AFTER's dependencies to the uncast value and re-applies
/// the same CAST on top, so the CAST becomes the outermost node.
fn cast_after_pattern() -> &'static TypedPatternMatcher {
    crate::cached_patterns! {
        After { passthrough: Cast { src, dtype }, deps }
            => |src, dtype, deps| {
                // Same dependencies, now on the uncast value; then re-apply the cast dtype.
                let new_after = src.after(deps.clone());
                Some(new_after.cast(dtype.clone()))
            },
    }
}
733
/// LOCAL/REG buffer devectorization (devectorizer.py:241-248).
///
/// Extended beyond Tinygrad: handles vector indices (not just scalar) by expanding
/// each index lane separately. Tinygrad asserts `idx.dtype.count == 1` which would
/// crash on local buffers with vector indices from UPCAST — Morok's optimizer can
/// produce such kernels (e.g., u3u3 upcast on matmul with local buffers).
///
/// Rules, in order:
/// 1. A DEFINE_LOCAL/DEFINE_REG whose pointee is a vector is retyped to a scalar pointer
///    and CAST back (`no_vectorized_buf`).
/// 2. INDEX(CAST(def), idx) → `no_vectorized_index`.
/// 3. INDEX(VECTORIZE(CAST(def), ...), idx) → `no_vectorized_index_precnt` with an all-zero
///    lane selection (one entry per VECTORIZE element).
/// 4. INDEX(GEP(CAST(def)), idx) → `no_vectorized_index_precnt` with the GEP's lane selection.
///
/// Rules 2–4 require the CAST target to be a pointer to a vector and its source to be a
/// DEFINE_LOCAL/DEFINE_REG (optionally behind `.or_after()`). Only the first INDEX operand is used.
fn devectorize_buf_and_index_patterns() -> &'static TypedPatternMatcher {
    crate::cached_patterns! {
        // DEFINE_LOCAL/REG with vector pointer → scalar pointer + CAST
        def if matches!(def.op(), Op::DefineLocal(_) | Op::DefineReg { .. })
            && def.ptrdtype().is_some_and(|(base, _, _)| base.vcount() > 1)
            => no_vectorized_buf(def),

        // INDEX(CAST(DEFINE_LOCAL/REG), idx) → scaled vector index
        // Handles both scalar and vector idx (Tinygrad only handles scalar).
        Index { buffer: Cast { src: buf, dtype: cast_dtype }, indices, gate }
            if is_vectorized_local_reg_cast(buf, cast_dtype)
            => no_vectorized_index(buf, indices, gate, cast_dtype),

        // INDEX(BROADCAST(CAST(...)), idx)
        // Only the first element is inspected by the guard; a broadcast repeats it, so every
        // lane selects pointer lane 0 (hence the all-zero `input_gep`).
        Index { buffer: Vectorize { elements }, indices, gate }
            if is_vectorized_broadcast_cast(elements)
            => {
                let first = elements.first()?;
                let Op::Cast { src: buf, dtype: DType::Ptr { base, .. } } = first.op() else { return None };
                let idx = indices.first()?;
                no_vectorized_index_precnt(buf, idx, gate, base.vcount(), &vec![0; elements.len()])
            },

        // INDEX(GEP(CAST(...)), idx)
        // The GEP's own lane list becomes `input_gep` for the precnt expansion.
        Index { buffer: Gep { vector: Cast { src: buf, dtype: cast_dtype }, indices: gep_indices }, indices, gate }
            if is_vectorized_local_reg_cast(buf, cast_dtype)
            => {
                let DType::Ptr { base, .. } = cast_dtype else { return None };
                let idx = indices.first()?;
                no_vectorized_index_precnt(buf, idx, gate, base.vcount(), gep_indices)
            },
    }
}
773
774fn is_vectorized_local_reg_cast(buf: &Arc<UOp>, cast_dtype: &DType) -> bool {
775    matches!(cast_dtype, DType::Ptr { base, .. } if base.vcount() > 1) && is_define_local_or_reg_or_after(buf)
776}
777
778fn is_vectorized_broadcast_cast(elements: &SmallVec<[Arc<UOp>; 4]>) -> bool {
779    elements.first().is_some_and(|f| {
780        matches!(f.op(), Op::Cast { dtype: DType::Ptr { base, .. }, src }
781        if base.vcount() > 1 && is_define_local_or_reg_or_after(src))
782    })
783}
784
785/// Uses `unwrap_after()` to handle `.or_after()` pattern.
786fn is_define_local_or_reg_or_after(uop: &Arc<UOp>) -> bool {
787    matches!(uop.unwrap_after().op(), Op::DefineLocal(_) | Op::DefineReg { .. })
788}
789
790/// Vector pointer → scalar pointer + CAST (devectorizer.py:225-226).
791fn no_vectorized_buf(buf: &Arc<UOp>) -> Option<Arc<UOp>> {
792    let (base, addrspace, size) = buf.ptrdtype()?;
793    let vcount = base.vcount();
794    if vcount <= 1 {
795        return None;
796    }
797
798    let scalar_base = base.base();
799    let new_size = size.map(|s| s * vcount);
800    let scalar_ptr_dtype =
801        DType::Ptr { base: Box::new(DType::Scalar(scalar_base)), addrspace, size: new_size, vcount: 1 };
802
803    let scalar_def = buf.with_dtype(scalar_ptr_dtype);
804    Some(scalar_def.cast(buf.dtype()))
805}
806
807/// INDEX(CAST(buf), idx) → INDEX(VECTORIZE([buf,...]), scaled_vec_idx) (devectorizer.py:228-231)
808///
809/// Handles both scalar idx (original Tinygrad path) and vector idx (Morok extension).
810/// For vector idx with vcount=V and pointer vcount=cnt, produces total = V*cnt lanes:
811///   for each lane i in idx: idx[i]*cnt + [0, 1, ..., cnt-1]
812fn no_vectorized_index(
813    buf: &Arc<UOp>,
814    indices: &SmallVec<[Arc<UOp>; 4]>,
815    gate: &Option<Arc<UOp>>,
816    cast_dtype: &DType,
817) -> Option<Arc<UOp>> {
818    let idx = indices.first()?;
819    let DType::Ptr { base, .. } = cast_dtype else { return None };
820    let cnt = base.vcount();
821    if cnt <= 1 {
822        return None;
823    }
824
825    let idx_vcount = idx.dtype().vcount();
826    let total = cnt * idx_vcount;
827    let buf_broadcast = buf.broadcast(total);
828
829    let final_idx = if idx_vcount == 1 {
830        // Scalar path (original Tinygrad logic)
831        let idx_broadcast = idx.broadcast(cnt);
832        let cnt_broadcast = idx.const_like(cnt as i64).broadcast(cnt);
833        idx_broadcast.mul(&cnt_broadcast).add(&create_index_vector(0..cnt as i64))
834    } else {
835        // Vector path: expand each lane by cnt elements
836        // idx = [a, b, c], cnt = 3 → [a*3+0, a*3+1, a*3+2, b*3+0, b*3+1, b*3+2, c*3+0, c*3+1, c*3+2]
837        let elements: SmallVec<[Arc<UOp>; 4]> = (0..idx_vcount)
838            .flat_map(|i| {
839                let lane = idx.gep(vec![i]);
840                let cnt_const = UOp::const_(lane.dtype(), ConstValue::Int(cnt as i64));
841                let scaled = lane.mul(&cnt_const);
842                (0..cnt).map(move |j| scaled.add(&UOp::const_(scaled.dtype(), ConstValue::Int(j as i64))))
843            })
844            .collect();
845        UOp::vectorize(elements)
846    };
847
848    // Expand gate to match total lanes if it's vectorized
849    let expanded_gate = if idx_vcount > 1 {
850        gate.as_ref().map(|g| {
851            if g.dtype().vcount() > 1 {
852                let elements: SmallVec<[Arc<UOp>; 4]> = (0..idx_vcount)
853                    .flat_map(|i| {
854                        let lane = g.gep(vec![i]);
855                        std::iter::repeat_n(lane, cnt)
856                    })
857                    .collect();
858                UOp::vectorize(elements)
859            } else {
860                g.broadcast(total)
861            }
862        })
863    } else {
864        gate.clone()
865    };
866
867    Some(
868        UOp::index()
869            .buffer(buf_broadcast)
870            .indices(vec![final_idx])
871            .maybe_gate(expanded_gate)
872            .ptr(true)
873            .call()
874            .expect("ICE unable to create index"),
875    )
876}
877
878fn create_index_vector(values: impl IntoIterator<Item = i64>) -> Arc<UOp> {
879    let elements: SmallVec<[Arc<UOp>; 4]> = values.into_iter().map(UOp::index_const).collect();
880    UOp::vectorize(elements)
881}
882
883/// INDEX with precnt multiplier (broadcast or gep case) (devectorizer.py:233-239)
884///
885/// Handles both scalar and vector idx. For vector idx, each lane is expanded
886/// independently with the same precnt/cnt scaling.
887fn no_vectorized_index_precnt(
888    buf: &Arc<UOp>,
889    idx: &Arc<UOp>,
890    gate: &Option<Arc<UOp>>,
891    cnt: usize,
892    input_gep: &[usize],
893) -> Option<Arc<UOp>> {
894    let precnt = input_gep.len();
895    let idx_vcount = idx.dtype().vcount();
896
897    if idx_vcount == 1 {
898        // Scalar path (original Tinygrad logic)
899        let total = cnt * precnt;
900        let gep_arg: Vec<usize> = (0..cnt).flat_map(|_| 0..precnt).collect();
901        let sum_arg = (0..cnt).flat_map(|i| input_gep.iter().map(move |&y| (i + y) as i64));
902
903        let buf_broadcast = buf.broadcast(total);
904        let final_idx =
905            idx.gep(gep_arg).mul(&idx.const_like(cnt as i64).broadcast(total)).add(&create_index_vector(sum_arg));
906
907        Some(
908            UOp::index()
909                .buffer(buf_broadcast)
910                .indices(vec![final_idx])
911                .maybe_gate(gate.clone())
912                .ptr(true)
913                .call()
914                .expect("ICE: unable to create index"),
915        )
916    } else {
917        // Vector path: expand each lane with the same precnt/cnt scaling
918        let per_lane = cnt * precnt;
919        let total = per_lane * idx_vcount;
920
921        let buf_broadcast = buf.broadcast(total);
922        let elements: SmallVec<[Arc<UOp>; 4]> = (0..idx_vcount)
923            .flat_map(|i| {
924                let lane = idx.gep(vec![i]);
925                let cnt_const = UOp::const_(lane.dtype(), ConstValue::Int(cnt as i64));
926                let scaled = lane.mul(&cnt_const);
927                (0..cnt).flat_map(move |c| {
928                    let s = scaled.clone();
929                    input_gep.iter().map(move |&y| s.add(&UOp::const_(s.dtype(), ConstValue::Int((c + y) as i64))))
930                })
931            })
932            .collect();
933        let final_idx = UOp::vectorize(elements);
934
935        let expanded_gate = gate.as_ref().map(|g| {
936            if g.dtype().vcount() > 1 {
937                let elements: SmallVec<[Arc<UOp>; 4]> = (0..idx_vcount)
938                    .flat_map(|i| {
939                        let lane = g.gep(vec![i]);
940                        std::iter::repeat_n(lane, per_lane)
941                    })
942                    .collect();
943                UOp::vectorize(elements)
944            } else {
945                g.broadcast(total)
946            }
947        });
948
949        Some(
950            UOp::index()
951                .buffer(buf_broadcast)
952                .indices(vec![final_idx])
953                .maybe_gate(expanded_gate)
954                .ptr(true)
955                .call()
956                .expect("ICE: unable to create index"),
957        )
958    }
959}
960
961// ============================================================================
962// Load Store Indexing Patterns (devectorizer.py:48-55)
963// ============================================================================
964
965/// INDEX(buf, x, true) → INDEX(buf, x, None)
966///
967/// Tinygrad (devectorizer.py:48-55) has 2 additional IMAGE-specific patterns:
968///
969/// 1. `simplify_valid_load(buf, x, cond)` for `INDEX(buf, WHERE(cond, x, Invalid))`
970/// 2. `simplify_valid_load(buf, x, c)` for `INDEX(buf, x:long, c:bool)` (post-lowering)
971///
972/// These use `uop_given_valid`/`parse_valid` and are tied to ImageDType.
973/// TODO: Add when implementing IMAGE backend support.
974pub fn load_store_indexing_patterns() -> &'static TypedPatternMatcher {
975    crate::cached_patterns! {
976        // INDEX(buf, idx, true) → INDEX(buf, idx, None) — remove trivially-true gate.
977        // Uses UOp::new directly to preserve the original dtype without builder inference,
978        // since the builder's dtype logic (ptr flag, element extraction) may not match
979        // the already-determined dtype on the matched INDEX node.
980        index @ Index { buffer, indices, gate: Some(g) }
981            if matches!(g.op(), Op::Const(cv) if matches!(cv.0, ConstValue::Bool(true)))
982            ~> UOp::new(Op::Index { buffer: buffer.clone(), indices: indices.clone(), gate: None }, index.dtype())
983    }
984}
985
986// ============================================================================
987// Add Loads Patterns (devectorizer.py:320-326)
988// ============================================================================
989
990/// Add LOAD to non-pointer INDEX, remove LOAD wrapper from STORE.
991pub fn pm_add_loads() -> &'static TypedPatternMatcher {
992    crate::cached_patterns! {
993        // Add LOAD to non-ptr INDEX: INDEX(buf, idx) → LOAD(INDEX(buf, idx))
994        // Skip if dtype is already Ptr (devectorizer.py:322-323)
995        idx @ Index { buffer, .. } if !is_ptr_or_image_dtype(&idx.dtype()) => {
996            let new_idx = idx.with_dtype(buffer.dtype());
997            Some(UOp::load().buffer(buffer.clone()).index(new_idx).dtype(idx.dtype().scalar_dtype()).call())
998        },
999
1000        // Remove LOAD wrapper from STORE: STORE(LOAD(x), ...) → STORE(x, ...)
1001        // (devectorizer.py:325)
1002        Store { index: Load { index: inner_idx, .. }, value, ranges }
1003            => Some(inner_idx.store_with_ranges(value.clone(), ranges.clone())),
1004    }
1005}
1006
1007fn is_ptr_or_image_dtype(dtype: &DType) -> bool {
1008    matches!(dtype, DType::Ptr { .. } | DType::Image { .. })
1009}
1010
1011// ============================================================================
1012// WMMA Accumulation Patterns (devectorizer.py:314-315)
1013// ============================================================================
1014
1015/// Fuse Add into WMMA's accumulator: WMMA(a,b,c) + add → WMMA(a,b,c+add)
1016/// Tensor cores have built-in accumulation, so this is more efficient.
1017pub fn pm_wmma_accumulate() -> &'static TypedPatternMatcher {
1018    crate::cached_patterns! {
1019        // WMMA + add → WMMA with fused accumulator (devectorizer.py:314-315)
1020        // Pattern: Add(WMMA(a, b, c), add) → WMMA(a, b, Add(c, add))
1021        Add(wmma @ Wmma { a, b, c, metadata }, add) => |wmma, a, b, c, metadata, add| {
1022            // Only fuse if types match
1023            if wmma.dtype() != add.dtype() {
1024                return None;
1025            }
1026            let new_c = c.add(add);
1027            Some(UOp::wmma(a.clone(), b.clone(), new_c, metadata.clone()))
1028        },
1029
1030        // Commutative: add + WMMA → WMMA with fused accumulator
1031        Add(add, wmma @ Wmma { a, b, c, metadata }) => |wmma, add, a, b, c, metadata| {
1032            if wmma.dtype() != add.dtype() {
1033                return None;
1034            }
1035            let new_c = c.add(add);
1036            Some(UOp::wmma(a.clone(), b.clone(), new_c, metadata.clone()))
1037        },
1038    }
1039}
1040
1041// ============================================================================
1042// Load Store Folding Patterns (devectorizer.py:114-126)
1043// ============================================================================
1044/// Tinygrad load_store_folding (devectorizer.py:119-132).
1045/// 6 patterns in one matcher, exactly matching Tinygrad's order.
1046pub fn load_store_folding_patterns() -> &'static TypedPatternMatcher {
1047    crate::cached_patterns! {
1048        // 1. expand_index: INDEX(VECTORIZE(buf), vec) → VECTORIZE(INDEX(buf, gep(0)), ...)
1049        index if is_vector_index(index) => expand_index_to_vectorize(index),
1050
1051        // 2. fold_expanded_index: VECTORIZE(INDEX, INDEX, ...) → GEP(PTRCAT(...), indices)
1052        midx @ Vectorize { elements } if elements.iter().all(|e| matches!(e.op(), Op::Index { .. }))
1053            => fold_expanded_index(midx),
1054
1055        // 3. GEP after LOAD: LOAD(buf, GEP(x)) → GEP(LOAD(buf, x))
1056        load @ Load { buffer, index: Gep { vector, indices } }
1057            => move_gep_after_load(load, buffer, vector, indices),
1058
1059        // 4. GEP on STORE: STORE(GEP(x), data) → STORE(x, GEP⁻¹(data))
1060        Store { index: Gep { vector, indices }, value, ranges }
1061            => move_gep_on_store(vector, indices, value, ranges),
1062
1063        // 5. PTRCAT after LOAD: LOAD(buf, PTRCAT) → CAT(LOAD(buf_i, ptr_i), ...)
1064        load @ Load { buffer, index: ptrcat @ PtrCat { sources } }
1065            => distribute_ptrcat_load(load, buffer, ptrcat, sources),
1066
1067        // 6. PTRCAT after STORE: STORE(PTRCAT, data) → GROUP(STORE(ptr_i, gep(data, i)), ...)
1068        Store { index: PtrCat { sources }, value, ranges }
1069            => distribute_ptrcat_store(sources, value, ranges),
1070    }
1071}
1072
1073// ============================================================================
1074// Correct Load Store Patterns (devectorizer.py:198-203)
1075// ============================================================================
1076
1077/// LOAD/STORE(CAST(INDEX)) → split by device fold lengths + image fixup.
1078pub fn correct_load_store_patterns() -> &'static TypedPatternMatcher {
1079    crate::cached_patterns! {
1080        // Split LOAD/STORE by device fold lengths
1081        ls @ Load { index: Cast { src: idx @ Index { buffer: _, .. }, .. }, .. }
1082            => split_load_store(ls, idx),
1083
1084        ls @ Store { index: Cast { src: idx @ Index { buffer: _, .. }, .. }, .. }
1085            => split_load_store(ls, idx),
1086
1087        // Image fixup patterns (devectorizer.py:176-196)
1088        ls @ Load { buffer: _, index: _, alt: _ } => image_fixup(ls),
1089        ls @ Store { index: _, value: _, ranges: _ } => image_fixup(ls),
1090    }
1091}
1092
1093// ============================================================================
1094// Pattern Predicates
1095// ============================================================================
1096
1097fn is_define_or_after(uop: &Arc<UOp>) -> bool {
1098    matches!(uop.unwrap_after().op(), Op::DefineLocal(_) | Op::DefineReg { .. } | Op::Param { device: None, .. })
1099}
1100
1101/// Matches INDEX(VECTORIZE(Defines.or_after()), vec_idx) only.
1102/// Tinygrad devectorizer.py:115 - expand_index only matches VECTORIZE of defines.
1103fn is_vector_index(uop: &Arc<UOp>) -> bool {
1104    let Op::Index { buffer, indices, .. } = uop.op() else { return false };
1105    let Some(idx) = indices.first() else { return false };
1106    if idx.dtype().vcount() <= 1 {
1107        return false;
1108    }
1109    let Op::Vectorize { elements } = buffer.op() else { return false };
1110    !elements.is_empty() && elements.iter().all(is_define_or_after)
1111}
1112
1113// ============================================================================
1114// GEP Movement Patterns (devectorizer.py:106-120)
1115// ============================================================================
1116
1117/// LOAD(GEP(ptr)) → GEP(LOAD(ptr)).
1118///
1119/// Tinygrad (devectorizer.py:117-118):
1120/// ```python
1121/// lambda gep, ld: ld.replace(dtype=ld.dtype.scalar().vec(gep.dtype.count),
1122///                            src=(gep.src[0],)+ld.src[1:]).gep(gep.arg)
1123/// ```
1124fn move_gep_after_load(
1125    load: &Arc<UOp>,
1126    buffer: &Arc<UOp>,
1127    gep_inner: &Arc<UOp>,
1128    gep_indices: &[usize],
1129) -> Option<Arc<UOp>> {
1130    let new_dtype = load.dtype().scalar_dtype().vec(gep_indices.len());
1131    let inner_load = load.replace().dtype(new_dtype).src(vec![buffer.clone(), gep_inner.clone()]).call();
1132    Some(inner_load.gep(gep_indices.to_vec()))
1133}
1134
1135/// STORE(GEP(ptr), data) → STORE(ptr, GEP⁻¹(data)). Inverts GEP indices.
1136fn move_gep_on_store(
1137    gep_inner: &Arc<UOp>,
1138    gep_indices: &[usize],
1139    value: &Arc<UOp>,
1140    ranges: &SmallVec<[Arc<UOp>; 4]>,
1141) -> Option<Arc<UOp>> {
1142    // Invert GEP: [2,0,1] → sorted by key → [1,2,0]
1143    let mut inverse_map: Vec<(usize, usize)> = gep_indices.iter().enumerate().map(|(i, &x)| (x, i)).collect();
1144    inverse_map.sort_by_key(|&(x, _)| x);
1145    let inverse_indices: Vec<usize> = inverse_map.iter().map(|&(_, i)| i).collect();
1146
1147    let reordered_value = value.gep(inverse_indices);
1148    Some(gep_inner.store_with_ranges(reordered_value, ranges.clone()))
1149}
1150
1151// ============================================================================
1152// expand_index (devectorizer.py:59-95)
1153// ============================================================================
1154
1155/// Vector INDEX → grouped PTRCAT. Generates scalar indices, simplifies, groups by root+offset.
1156/// Phase 1a: Expand vector INDEX into VECTORIZE of scalar INDEXes.
1157/// Matches Tinygrad's `expand_index` (devectorizer.py:59-62).
1158/// NO inner rewrite — the outer fixed-point (sym) simplifies GEP expressions.
1159fn expand_index_to_vectorize(index: &Arc<UOp>) -> Option<Arc<UOp>> {
1160    let Op::Index { buffer, indices, gate } = index.op() else { return None };
1161    assert!(indices.len() <= 1, "ICE: expand_index_to_vectorize called with multi-index INDEX (len={})", indices.len());
1162    let vec = indices.first()?;
1163    let count = vec.dtype().vcount();
1164
1165    let buf = if let Op::Vectorize { elements } = buffer.op() { elements.first()?.clone() } else { buffer.clone() };
1166
1167    let scalar_indices: Vec<_> = (0..count)
1168        .map(|i| {
1169            let lane_gate = gate.as_ref().map(|g| if g.dtype().vcount() > 1 { g.gep(vec![i]) } else { g.clone() });
1170            UOp::index()
1171                .buffer(buf.clone())
1172                .indices(vec![vec.gep(vec![i])])
1173                .maybe_gate(lane_gate)
1174                .ptr(true)
1175                .call()
1176                .expect("ICE: unable to create index")
1177        })
1178        .collect();
1179
1180    Some(UOp::vectorize(scalar_indices.into()))
1181}
1182
1183/// Phase 1b: Fold VECTORIZE of scalar INDEXes into PTRCAT groupings.
1184/// Matches Tinygrad's `fold_expanded_index` (devectorizer.py:64-100).
1185/// By this point, the outer sym fixed-point has simplified GEP expressions
1186/// into concrete root+offset form so grouping can identify consecutive accesses.
1187fn fold_expanded_index(midx: &Arc<UOp>) -> Option<Arc<UOp>> {
1188    let Op::Vectorize { elements: sources } = midx.op() else { return None };
1189    let count = sources.len();
1190    if count == 0 {
1191        return None;
1192    }
1193
1194    // Verify all elements are INDEX and share the same buffer
1195    let first_buf = match sources[0].op() {
1196        Op::Index { buffer, .. } => buffer,
1197        _ => return None,
1198    };
1199    if !sources.iter().all(|s| matches!(s.op(), Op::Index { buffer, .. } if Arc::ptr_eq(buffer, first_buf))) {
1200        return None;
1201    }
1202    let buf = first_buf;
1203
1204    // Extract (valid, root, offset, gate) for each lane.
1205    struct LaneData {
1206        valid: Arc<UOp>,
1207        root: Arc<UOp>,
1208        offset: i64,
1209        gate_id: u64,
1210    }
1211    let mut lane_data: Vec<(usize, LaneData)> = Vec::with_capacity(count);
1212
1213    for (lane, idx_op) in sources.iter().enumerate() {
1214        let Op::Index { indices: simp_indices, gate: lane_gate, .. } = idx_op.op() else { continue };
1215        let idx = simp_indices.first()?.get_idx();
1216        let valid = simp_indices.first()?.get_valid();
1217        let gate_id = lane_gate.as_ref().map_or(u64::MAX, |g| g.id);
1218
1219        let (root, offset) = match idx.op() {
1220            Op::Invalid => (UOp::invalid_marker(), 0),
1221            Op::Binary(BinaryOp::Add, l, r) if matches!(r.op(), Op::Const(cv) if matches!(cv.0, ConstValue::Int(_))) => {
1222                let Op::Const(cv) = r.op() else { unreachable!() };
1223                let ConstValue::Int(off) = cv.0 else { unreachable!() };
1224                (l.clone(), off)
1225            }
1226            Op::Binary(BinaryOp::Add, l, r) if matches!(l.op(), Op::Const(cv) if matches!(cv.0, ConstValue::Int(_))) => {
1227                let Op::Const(cv) = l.op() else { unreachable!() };
1228                let ConstValue::Int(off) = cv.0 else { unreachable!() };
1229                (r.clone(), off)
1230            }
1231            Op::Const(cv) if matches!(cv.0, ConstValue::Int(_)) => {
1232                let ConstValue::Int(off) = cv.0 else { unreachable!() };
1233                (UOp::index_const(0), off)
1234            }
1235            _ => (idx.clone(), 0),
1236        };
1237
1238        lane_data.push((lane, LaneData { valid, root, offset, gate_id }));
1239    }
1240
1241    // Build grouping map
1242    let mut offsets_by_root: HashMap<(u64, u64, u64), HashMap<i64, Vec<usize>>> = HashMap::new();
1243    for (lane, data) in &lane_data {
1244        let key = (data.valid.id, data.root.id, data.gate_id);
1245        offsets_by_root.entry(key).or_default().entry(data.offset).or_default().push(*lane);
1246    }
1247
1248    // Group consecutive offsets and build PTRCAT
1249    let mut ret = Vec::new();
1250    let mut idxs: Vec<Option<usize>> = vec![None; count];
1251    let mut global_offset = 0;
1252
1253    for offsets in offsets_by_root.values() {
1254        let groups = group_consecutive_offsets_from_map(offsets);
1255        for grp in groups {
1256            let lidx = sources[offsets[&grp[0]][0]].clone();
1257            let ptr = if grp.len() > 1 { lidx.cast(make_vec_ptr_dtype(buf, grp.len())) } else { lidx };
1258            for (i, &offset) in grp.iter().enumerate() {
1259                for &lane in &offsets[&offset] {
1260                    idxs[lane] = Some(global_offset + i);
1261                }
1262            }
1263            ret.push(ptr);
1264            global_offset += grp.len();
1265        }
1266    }
1267
1268    if idxs.iter().any(|x| x.is_none()) {
1269        return None;
1270    }
1271
1272    let DType::Ptr { base, addrspace, size, .. } = buf.dtype().clone() else { return None };
1273    let scalar_ptr = DType::Ptr { base: Box::new(DType::Scalar(base.scalar()?)), addrspace, size, vcount: 1 };
1274    let ptrcat_dtype = scalar_ptr.vec(global_offset);
1275    let ptrcat = UOp::ptrcat().sources(ret).dtype(ptrcat_dtype).call();
1276    let gep_indices: Vec<usize> = idxs.into_iter().map(|x| x.unwrap()).collect();
1277
1278    Some(ptrcat.gep(gep_indices))
1279}
1280
1281/// Groups offsets where `offset - index` is constant.
1282///
1283/// Returns groups of consecutive offset keys. Multi-lane offsets (broadcasts)
1284/// are handled by the caller — all lanes sharing an offset get the same PTRCAT slot.
1285/// Matches Tinygrad's `itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])`.
1286fn group_consecutive_offsets_from_map(offsets_map: &HashMap<i64, Vec<usize>>) -> Vec<Vec<i64>> {
1287    if offsets_map.is_empty() {
1288        return vec![];
1289    }
1290
1291    let sorted: Vec<_> = offsets_map.keys().copied().sorted().collect();
1292    sorted
1293        .iter()
1294        .copied()
1295        .enumerate()
1296        .chunk_by(|(idx, offset)| offset - (*idx as i64))
1297        .into_iter()
1298        .map(|(_, group)| group.map(|(_, offset)| offset).collect())
1299        .collect()
1300}
1301
1302fn make_vec_ptr_dtype(buffer: &Arc<UOp>, vec_len: usize) -> DType {
1303    let (base_dtype, addrspace) = buffer
1304        .ptrdtype()
1305        .map(|(base, addrspace, _)| (base.base(), addrspace))
1306        .unwrap_or_else(|| (buffer.dtype().base(), AddrSpace::Global));
1307    let vec_dtype = DType::Vector { scalar: base_dtype, count: vec_len };
1308    DType::Ptr { base: Box::new(vec_dtype), addrspace, size: Some(vec_len), vcount: 1 }
1309}
1310
1311// ============================================================================
1312// PTRCAT Distribution (devectorizer.py:97-104, 122-123)
1313// ============================================================================
1314
1315/// LOAD(PTRCAT) → CAT(LOADs). CAT dtype = ptrcat.dtype.base.vec(ptrcat.dtype.vcount)
1316/// LOAD(buf, PTRCAT(idx0,idx1,...)) → CAT(LOAD(buf_i, idx_i), ...)
1317///
1318/// Matches Tinygrad devectorizer.py:128-129:
1319///   ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src
1320///
1321/// Each PtrCat source is a scalar INDEX(buf, offset). The distributed scalar
1322/// LOAD uses that INDEX directly, with the scalar buffer from GEP(buffer, i).
1323fn distribute_ptrcat_load(
1324    load: &Arc<UOp>,
1325    buffer: &Arc<UOp>,
1326    ptrcat: &Arc<UOp>,
1327    sources: &[Arc<UOp>],
1328) -> Option<Arc<UOp>> {
1329    let loads: Vec<Arc<UOp>> = sources
1330        .iter()
1331        .enumerate()
1332        .map(|(i, ptr)| {
1333            let load_dtype = match ptr.dtype() {
1334                DType::Ptr { base, .. } => base.as_ref().clone(),
1335                other => other.clone(),
1336            };
1337            // Extract scalar buffer for this lane.
1338            // Tinygrad: ld.replace(src=(x,)+ld.src[1:]) — PtrCat source IS the full
1339            // INDEX(buf, idx), so the scalar load doesn't need the outer buffer at all.
1340            // Each PtrCat source already contains its own buffer reference.
1341            // For VECTORIZE(buf, buf, ...) just use the scalar element.
1342            let scalar_buf = match buffer.op() {
1343                Op::Vectorize { elements, .. } => elements.get(i).cloned().unwrap_or_else(|| buffer.clone()),
1344                _ => buffer.clone(),
1345            };
1346            let alt = match load.op() {
1347                Op::Load { alt, .. } => alt.clone(),
1348                _ => None,
1349            };
1350            UOp::load().buffer(scalar_buf).index(ptr.clone()).maybe_alt(alt).dtype(load_dtype).call()
1351        })
1352        .collect();
1353
1354    let cat_dtype = DType::Scalar(ptrcat.dtype().base()).vec(ptrcat.dtype().vcount());
1355    Some(UOp::cat().sources(loads).dtype(cat_dtype).call())
1356}
1357
1358/// STORE(PTRCAT, data) → GROUP(STOREs with GEP-sliced data)
1359fn distribute_ptrcat_store(
1360    sources: &[Arc<UOp>],
1361    value: &Arc<UOp>,
1362    ranges: &SmallVec<[Arc<UOp>; 4]>,
1363) -> Option<Arc<UOp>> {
1364    let value_vcount = value.dtype().vcount();
1365    let mut stores = Vec::new();
1366    let mut offset = 0usize;
1367
1368    for ptr in sources.iter() {
1369        let ptr_count = ptr_element_count(ptr);
1370        debug_assert!(offset + ptr_count <= value_vcount, "PTRCAT size mismatch");
1371        let gep_indices: Vec<usize> = (offset..offset + ptr_count).collect();
1372        let store_value = value.gep(gep_indices);
1373        stores.push(ptr.store_with_ranges(store_value, ranges.clone()));
1374        offset += ptr_count;
1375    }
1376
1377    Some(UOp::group(stores.into_iter().collect()))
1378}
1379
1380/// Get the element count for a PTRCAT source pointer.
1381///
1382/// This should return the vcount of the base type, NOT the buffer size.
1383/// For `Ptr { base: Scalar(Float32), size: Some(4), .. }` → 1 (scalar access)
1384/// For `Ptr { base: Vector { count: 2, .. }, size: Some(2), .. }` → 2 (vec2 access)
1385///
1386/// Tinygrad uses `dtype.count` which returns the base type's element count.
1387fn ptr_element_count(ptr: &Arc<UOp>) -> usize {
1388    match ptr.dtype() {
1389        DType::Ptr { base, .. } => base.vcount(),
1390        _ => 1,
1391    }
1392}
1393
1394// ============================================================================
1395// split_load_store (devectorizer.py:130-174)
1396// ============================================================================
1397
1398/// Split LOAD/STORE into multiple chunks by device fold lengths (devectorizer.py:130-174).
1399fn split_load_store(ls: &Arc<UOp>, idx: &Arc<UOp>) -> Option<Arc<UOp>> {
1400    let Op::Index { buffer: buf, indices, .. } = idx.op() else { return None };
1401
1402    // sz = ls.src[0].dtype.count (Tinygrad: size from index dtype)
1403    // For Ptr types, we need base.vcount() — the pointee's vector count.
1404    // index.dtype().vcount() returns the pointer's vector count (always 1 for CAST'd pointers).
1405    let sz = match ls.op() {
1406        Op::Load { index, .. } | Op::Store { index, .. } => ptr_element_count(index),
1407        _ => return None,
1408    };
1409    if sz == 1 {
1410        return None;
1411    }
1412
1413    // Fold lengths (devectorizer.py:138-152)
1414    let buf_dtype = buf.dtype();
1415    static IS_AMX: std::sync::LazyLock<bool> =
1416        std::sync::LazyLock::new(|| std::env::var("MOROK_AMX").is_ok_and(|v| v == "1"));
1417    let is_amx = *IS_AMX;
1418
1419    // AMX TC accumulators are stored in DEFINE_REG (AddrSpace::Reg) but need vector stores.
1420    // For STORE: check if VALUE comes from an AMX TC accumulator (DEFINE_REG with AddrSpace::Reg).
1421    // For LOAD: check if BUFFER is an AMX TC accumulator.
1422    fn is_amx_tc_reg_ptr(dtype: &DType, sz: usize) -> bool {
1423        sz >= 16
1424            && dtype.base().is_float()
1425            && matches!(dtype, DType::Ptr { addrspace: AddrSpace::Reg, .. } | DType::Vector { .. })
1426    }
1427
1428    // Helper to find underlying LOAD through GEP chains
1429    fn find_underlying_load(uop: &Arc<UOp>) -> Option<Arc<UOp>> {
1430        match uop.op() {
1431            Op::Gep { vector, .. } => find_underlying_load(vector),
1432            Op::Load { .. } => Some(uop.clone()),
1433            _ => None,
1434        }
1435    }
1436
1437    let is_amx_tc_acc = match ls.op() {
1438        Op::Store { value, .. } => {
1439            // Check if value comes from LOAD of DEFINE_REG (AMX accumulator)
1440            // Value may be GEP(LOAD(...)), so trace through GEP chains
1441            if let Some(load) = find_underlying_load(value) {
1442                if let Op::Load { index, .. } = load.op() {
1443                    if let Op::Index { buffer, .. } = index.op() {
1444                        let buf_dtype = buffer.dtype();
1445                        is_amx && is_amx_tc_reg_ptr(&buf_dtype, sz)
1446                    } else {
1447                        false
1448                    }
1449                } else {
1450                    false
1451                }
1452            } else {
1453                false
1454            }
1455        }
1456        Op::Load { .. } => is_amx && is_amx_tc_reg_ptr(&buf_dtype, sz),
1457        _ => false,
1458    };
1459
1460    // Don't fold for non-float types or Image, but allow AMX TC accumulators.
1461    // Tinygrad: no_fold is False for AMX operations since they use in-memory accumulation.
1462    let no_fold = (!buf_dtype.base().is_float() && !matches!(buf_dtype, DType::Image { .. }))
1463        || (matches!(buf_dtype, DType::Ptr { addrspace: AddrSpace::Reg, .. }) && !is_amx_tc_acc);
1464
1465    let mut lengths = if no_fold {
1466        vec![1]
1467    } else if matches!(buf_dtype, DType::Image { .. }) {
1468        vec![4, 1]
1469    } else if is_amx {
1470        vec![16, 8, 4, 2, 1] // AMX: wider folds matching 64-byte row stride
1471    } else {
1472        // Tinygrad uses ctx.supports_float4 from Renderer context (devectorizer.py:155-157).
1473        // Hardcoded [4,2,1] matches the default supports_float4 path.
1474        // TODO: Pass Renderer context through when adding backends with different fold lengths.
1475        vec![4, 2, 1]
1476    };
1477
1478    // Filter by divisibility (devectorizer.py:155-156)
1479    // NOTE: Tinygrad has `must_divide=False` for DSP devices which skips this check.
1480    // DSP uses larger fold lengths [128,64,32,16,8,4] without divisibility requirement.
1481    // AMX TC accumulators also skip divisibility check to allow vector stores.
1482    if !is_amx_tc_acc && let Some(offset) = indices.first() {
1483        lengths.retain(|&len| offset_divides_evenly(offset, len));
1484    }
1485
1486    // Split loop (devectorizer.py:159-170)
1487    let scalar_dtype = buf_dtype.scalar_dtype();
1488    let mut ret = Vec::new();
1489    let mut pos = 0usize;
1490
1491    while pos < sz {
1492        for &fold_len in &lengths {
1493            if pos + fold_len > sz {
1494                continue;
1495            }
1496            let lidx = if pos == 0 { idx.clone() } else { offset_index(idx, pos as i64) };
1497            let lidx = if fold_len > 1 { lidx.cast(make_vec_ptr_dtype(buf, fold_len)) } else { lidx };
1498
1499            match ls.op() {
1500                Op::Store { value, ranges, .. } => {
1501                    ret.push(lidx.store_with_ranges(value.gep((pos..pos + fold_len).collect()), ranges.clone()));
1502                }
1503                Op::Load { buffer, .. } => {
1504                    ret.push(UOp::load().buffer(buffer.clone()).index(lidx).dtype(scalar_dtype.vec(fold_len)).call());
1505                }
1506                _ => return None,
1507            }
1508            pos += fold_len;
1509            break;
1510        }
1511    }
1512
1513    if ret.len() <= 1 {
1514        return None;
1515    }
1516
1517    match ls.op() {
1518        Op::Load { .. } => Some(UOp::cat().sources(ret).dtype(scalar_dtype.vec(sz)).call()),
1519        Op::Store { .. } => Some(UOp::group(ret.into_iter().collect())),
1520        _ => None,
1521    }
1522}
1523
1524/// Check if offset expression divides evenly by len (devectorizer.py:703-711).
1525/// Conservative: false for unknown expressions.
1526fn offset_divides_evenly(offset: &Arc<UOp>, len: usize) -> bool {
1527    // len==0 is invalid (can't divide by zero), return false defensively
1528    if len == 0 {
1529        return false;
1530    }
1531    // len==1 means no vectorization, trivially true
1532    if len == 1 {
1533        return true;
1534    }
1535    let v = len as i64;
1536
1537    match offset.op() {
1538        // CONST: check modulo
1539        Op::Const(cv) => matches!(cv.0, ConstValue::Int(n) if n % v == 0),
1540
1541        // VCONST: all elements must divide evenly
1542        Op::VConst { values } => values.iter().all(|val| matches!(val, ConstValue::Int(n) if n % v == 0)),
1543
1544        // ADD: both operands must divide
1545        Op::Binary(BinaryOp::Add, left, right) => offset_divides_evenly(left, len) && offset_divides_evenly(right, len),
1546
1547        // MUL: either operand divides (matching Tinygrad - no n >= len check!)
1548        Op::Binary(BinaryOp::Mul, left, right) => {
1549            let check_const =
1550                |c: &Arc<UOp>| matches!(c.op(), Op::Const(cv) if matches!(cv.0, ConstValue::Int(n) if n % v == 0));
1551            check_const(left)
1552                || check_const(right)
1553                || offset_divides_evenly(left, len)
1554                || offset_divides_evenly(right, len)
1555        }
1556
1557        _ => false,
1558    }
1559}
1560
1561fn offset_index(idx: &Arc<UOp>, offset: i64) -> Arc<UOp> {
1562    let Op::Index { buffer, indices, gate } = idx.op() else {
1563        return idx.clone();
1564    };
1565    let new_indices: SmallVec<[Arc<UOp>; 4]> = indices
1566        .iter()
1567        .enumerate()
1568        .map(|(i, index_expr)| if i == 0 { index_expr.add(&index_expr.const_like(offset)) } else { index_expr.clone() })
1569        .collect();
1570
1571    UOp::index()
1572        .buffer(buffer.clone())
1573        .indices(new_indices)
1574        .maybe_gate(gate.clone())
1575        .ptr(true)
1576        .call()
1577        .expect("ICE: unable to create index")
1578}
1579
1580// ============================================================================
1581// image_fixup (devectorizer.py:176-196)
1582// ============================================================================
1583
1584/// Convert linear image index to 2D (x, y) coordinates.
1585///
1586/// For images with shape [height, width]:
1587///   x_coord = (linear_idx // 4) % width
1588///   y_coord = linear_idx // (4 * width)
1589///
1590/// Handles two cases:
1591/// 1. Normal image load/store with CAST from expand_index (dtype.count == 4)
1592/// 2. Unfoldable image load (no CAST, direct INDEX with ImageDType)
1593fn image_fixup(ls: &Arc<UOp>) -> Option<Arc<UOp>> {
1594    // Case 1: LOAD/STORE(CAST(INDEX)) where INDEX.buffer is ImageDType
1595    // The CAST should be to a vec4 pointer
1596    let (index, is_load) = match ls.op() {
1597        Op::Load { index, .. } => (index, true),
1598        Op::Store { index, .. } => (index, false),
1599        _ => return None,
1600    };
1601
1602    // Check for CAST(INDEX) pattern
1603    if let Op::Cast { src: inner_idx, dtype: cast_dtype } = index.op()
1604        && let Op::Index { buffer: img_buf, indices, gate } = inner_idx.op()
1605    {
1606        // Check if buffer is ImageDType
1607        let DType::Image { shape, .. } = img_buf.dtype() else { return None };
1608
1609        // Image must be casted to vec4 (RGBA)
1610        if cast_dtype.vcount() != 4 {
1611            return None;
1612        }
1613
1614        // Get the first index (linear index)
1615        let lin_idx = indices.first()?;
1616        let x = lin_idx.get_idx();
1617        let valid = lin_idx.get_valid();
1618
1619        // Get image width (shape[1])
1620        let width = shape.get(1).copied().unwrap_or(1) as i64;
1621
1622        // Create 2D index: x_coord = (x // 4) % width, y_coord = x // (4 * width)
1623        let four = UOp::index_const(4);
1624        let width_const = UOp::index_const(width);
1625        let stride = UOp::index_const(4 * width);
1626
1627        let x_coord = x.idiv(&four).mod_(&width_const);
1628        let y_coord = x.idiv(&stride);
1629
1630        // Create vec2 index
1631        let oidx = UOp::vectorize(smallvec::smallvec![x_coord, y_coord]);
1632
1633        // Apply validity if not always true
1634        let new_idx_expr = if matches!(valid.op(), Op::Const(cv) if cv.0 == ConstValue::Bool(true)) {
1635            oidx
1636        } else {
1637            oidx.valid(valid)
1638        };
1639
1640        // Create new INDEX with 2D coordinates
1641        // Use ptr(true) when inner_idx has Ptr dtype, otherwise preserve element dtype
1642        let new_idx = if matches!(inner_idx.dtype(), DType::Ptr { .. }) {
1643            UOp::index()
1644                .buffer(img_buf.clone())
1645                .indices(vec![new_idx_expr])
1646                .maybe_gate(gate.clone())
1647                .ptr(true)
1648                .call()
1649                .ok()?
1650        } else {
1651            UOp::index()
1652                .buffer(img_buf.clone())
1653                .indices(vec![new_idx_expr])
1654                .maybe_gate(gate.clone())
1655                .dtype(inner_idx.dtype())
1656                .call()
1657                .ok()?
1658        };
1659
1660        // Replace the index in LOAD/STORE
1661        return Some(ls.replace().src(vec![new_idx]).call());
1662    }
1663
1664    // Case 2: Direct INDEX with ImageDType (unfoldable image, no CAST)
1665    if let Op::Index { buffer: img_buf, indices, gate } = index.op() {
1666        let DType::Image { shape, .. } = img_buf.dtype() else { return None };
1667
1668        // Get the first index
1669        let lin_idx = indices.first()?;
1670        let x = lin_idx.get_idx();
1671
1672        // Check if it's already a 2D index (vec2)
1673        if x.dtype().vcount() == 2 {
1674            return None; // Already converted
1675        }
1676
1677        // Only LOAD is supported for unfoldable images
1678        if !is_load {
1679            tracing::warn!("image_fixup: STORE with unfoldable image not supported");
1680            return None;
1681        }
1682
1683        let valid = lin_idx.get_valid();
1684
1685        // Get image width
1686        let width = shape.get(1).copied().unwrap_or(1) as i64;
1687
1688        // Create 2D index
1689        let four = UOp::index_const(4);
1690        let width_const = UOp::index_const(width);
1691        let stride = UOp::index_const(4 * width);
1692
1693        let x_coord = x.idiv(&four).mod_(&width_const);
1694        let y_coord = x.idiv(&stride);
1695
1696        let oidx = UOp::vectorize(smallvec::smallvec![x_coord, y_coord]);
1697
1698        let new_idx_expr = if matches!(valid.op(), Op::Const(cv) if cv.0 == ConstValue::Bool(true)) {
1699            oidx
1700        } else {
1701            oidx.valid(valid)
1702        };
1703
1704        // Use ptr(true) when index has Ptr dtype, otherwise preserve element dtype
1705        let new_idx = if matches!(index.dtype(), DType::Ptr { .. }) {
1706            UOp::index()
1707                .buffer(img_buf.clone())
1708                .indices(vec![new_idx_expr])
1709                .maybe_gate(gate.clone())
1710                .ptr(true)
1711                .call()
1712                .ok()?
1713        } else {
1714            UOp::index()
1715                .buffer(img_buf.clone())
1716                .indices(vec![new_idx_expr])
1717                .maybe_gate(gate.clone())
1718                .dtype(index.dtype())
1719                .call()
1720                .ok()?
1721        };
1722
1723        // For unfoldable images: load vec4, then select correct element
1724        // result = reduce(lambda ret, i: (x % 4).ne(i).where(ret, vec_load.gep(i)), range(4), nan)
1725        let vec4_dtype = ls.dtype().vec(4);
1726        let vec_load = UOp::load().buffer(ls.load_buffer()?).index(new_idx).dtype(vec4_dtype).call();
1727
1728        // Build: WHERE(x%4 != 0, WHERE(x%4 != 1, WHERE(x%4 != 2, WHERE(x%4 != 3, nan, gep3), gep2), gep1), gep0)
1729        let x_mod_4 = x.mod_(&four);
1730        let nan = ls.const_like(ConstValue::Float(f64::NAN));
1731
1732        let result = (0..4).rev().fold(nan, |ret, i| {
1733            let i_const = UOp::index_const(i);
1734            let not_eq = x_mod_4.ne(&i_const);
1735            let gep_i = vec_load.gep(vec![i as usize]);
1736            UOp::try_where(not_eq, ret, gep_i).expect("WHERE")
1737        });
1738
1739        return Some(result);
1740    }
1741
1742    None
1743}
1744
1745// ============================================================================
1746// pm_reduce: Convert REDUCE to explicit accumulator pattern (Tinygrad devectorizer.py:310-316)
1747// ============================================================================
1748
1749use crate::symbolic::dce::reduce_identity;
1750
/// Build the matcher that lowers REDUCE into an explicit register-accumulator pattern
/// (Tinygrad `pm_reduce`, devectorizer.py:310-316).
///
/// Two rules are registered, both sharing a `ReduceContext`:
/// - every `REDUCE` is rewritten by `reduce_to_acc`;
/// - every `SINK` is offered to `ReduceContext::merge_reduce_ends`. That call merges the
///   END nodes `reduce_to_acc` registered under identical reduce-range sets.
///
/// For `REDUCE(src, ranges, Add)` with dtype Float32, `reduce_to_acc` builds the following.
/// Here `input_ranges` are the RANGEs feeding `src` that are neither in `ranges` nor
/// already closed by an END.
/// ```text
/// acc      = DEFINE_REG_TYPED(1, Float32)
/// acc_init = STORE(INDEX(AFTER(acc, input_ranges...), 0), identity)  // 0 for Add
/// acc_loop = INDEX(AFTER(acc, acc_init, ranges...), 0)
/// new_val  = acc_loop + src      // src split into horizontal lane groups, folded left
/// end      = END(STORE(INDEX(acc, 0), new_val), ranges)
/// result   = INDEX(AFTER(acc, end), 0)
/// ```
/// With empty `ranges`, only the horizontal reduction of `src` is returned.
///
/// The rewrite reads the accumulator through INDEX nodes, not LOADs. Presumably
/// `pm_add_loads`, which runs afterwards, turns those reads into LOADs — TODO confirm.
///
/// This runs EARLY (before pm_add_loads, before main devectorize) to eliminate
/// REDUCE before other patterns see it. Matches Tinygrad's pm_reduce.
pub fn pm_reduce() -> TypedPatternMatcher<ReduceContext> {
    crate::patterns! {
        @context ReduceContext;

        // Match ALL REDUCEs - empty ranges handled by returning reduced value directly
        red @ Reduce(_, ..) => {
            reduce_to_acc(red, ctx)
        },

        // Merge END nodes sharing the same reduce ranges (Tinygrad merge_reduce_ends).
        // The ENDs were recorded in the shared context by reduce_to_acc.
        Sink { sources: _sources } => {
            ctx.merge_reduce_ends(_sources)
        },
    }
}
1793
1794/// Horizontal reduce for accumulator pattern (devectorizer.py:283-289).
1795fn horizontal_reduce(inp: &Arc<UOp>, out_dtype: &DType) -> Vec<Arc<UOp>> {
1796    if inp.dtype() == *out_dtype {
1797        return vec![inp.clone()];
1798    }
1799    let inp_vcount = inp.dtype().vcount();
1800    let out_vcount = out_dtype.vcount();
1801    assert!(
1802        inp_vcount.is_multiple_of(out_vcount),
1803        "horizontal mismatch: inp.dtype={:?} (vcount={}), out_dtype={:?} (vcount={})",
1804        inp.dtype(),
1805        inp_vcount,
1806        out_dtype,
1807        out_vcount
1808    );
1809    let horizontal_amount = inp_vcount / out_vcount;
1810    (0..horizontal_amount).map(|i| inp.gep((i..inp_vcount).step_by(horizontal_amount).collect())).collect()
1811}
1812
1813/// Convert REDUCE to explicit accumulator pattern (devectorizer.py:291-308).
1814fn reduce_to_acc(red: &Arc<UOp>, ctx: &mut ReduceContext) -> Option<Arc<UOp>> {
1815    let Op::Reduce { src: inp, ranges: reduce_range, reduce_op } = red.op() else { return None };
1816
1817    let lst = horizontal_reduce(inp, &red.dtype());
1818    debug_assert!(lst.iter().all(|x| x.dtype() == red.dtype()), "horizontal reduction mismatch");
1819
1820    // No ranges → just horizontal reduction
1821    if reduce_range.is_empty() {
1822        return lst.into_iter().reduce(|a, b| apply_reduce_binary(*reduce_op, a, b, &red.dtype()));
1823    }
1824
1825    // Find input_ranges: ranges in topo that are not reduce_range and not ended
1826    let topo = inp.toposort();
1827    let ended: HashSet<u64> = topo
1828        .iter()
1829        .filter_map(|n| if let Op::End { ranges, .. } = n.op() { Some(ranges.iter().map(|r| r.id)) } else { None })
1830        .flatten()
1831        .collect();
1832    let reduce_ids: HashSet<u64> = reduce_range.iter().map(|r| r.id).collect();
1833    let input_ranges: SmallVec<[Arc<UOp>; 4]> = topo
1834        .iter()
1835        .filter(|n| matches!(n.op(), Op::Range { .. }) && !reduce_ids.contains(&n.id) && !ended.contains(&n.id))
1836        .cloned()
1837        .collect();
1838
1839    // Set up accumulator
1840    let identity = reduce_identity(*reduce_op, red.dtype());
1841    let acc = UOp::define_reg_typed(1, red.dtype());
1842    let zero = UOp::index_const(0);
1843    let make_idx = |buf: Arc<UOp>| UOp::index().buffer(buf).indices(vec![zero.clone()]).call().expect("index");
1844
1845    // acc_init = acc.after(*input_ranges).index(0).store(identity)
1846    let acc_init = make_idx(acc.after(input_ranges)).store_value(identity);
1847
1848    // lst = [acc.after(acc_init, *reduce_range).index(0)] + lst
1849    let mut loop_deps: SmallVec<[Arc<UOp>; 4]> = smallvec::smallvec![acc_init];
1850    loop_deps.extend(reduce_range.iter().cloned());
1851    let acc_loop = make_idx(acc.after(loop_deps));
1852    let lst_with_acc = std::iter::once(acc_loop).chain(lst);
1853
1854    // ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
1855    let ret = lst_with_acc.reduce(|a, b| apply_reduce_binary(*reduce_op, a, b, &red.dtype()))?;
1856
1857    // return acc.after(acc.index(0).store(ret).end(*reduce_range)).index(0)
1858    let store_end = make_idx(acc.clone()).store_value(ret).end(reduce_range.clone());
1859    ctx.register_end(&store_end);
1860    Some(make_idx(acc.after(smallvec::smallvec![store_end])))
1861}
1862
1863/// Apply binary reduce operation between two values.
1864fn apply_reduce_binary(reduce_op: ReduceOp, a: Arc<UOp>, b: Arc<UOp>, dtype: &DType) -> Arc<UOp> {
1865    debug_assert!(a.dtype() == b.dtype(), "reduce operand dtype mismatch");
1866    match reduce_op {
1867        ReduceOp::Add => UOp::new(Op::Binary(BinaryOp::Add, a, b), dtype.clone()),
1868        ReduceOp::Mul => UOp::new(Op::Binary(BinaryOp::Mul, a, b), dtype.clone()),
1869        ReduceOp::Max => UOp::new(Op::Binary(BinaryOp::Max, a, b), dtype.clone()),
1870        ReduceOp::Min => {
1871            let cond = UOp::new(Op::Binary(BinaryOp::Lt, a.clone(), b.clone()), DType::Bool.vec(dtype.vcount()));
1872            UOp::try_where(cond, a, b).expect("WHERE")
1873        }
1874    }
1875}