1use std::collections::HashMap;
16use std::collections::HashSet;
17use std::sync::Arc;
18use std::sync::LazyLock;
19
20use itertools::Itertools;
21use morok_dtype::{AddrSpace, DType, ScalarDType};
22use morok_ir::{BinaryOp, ConstValue, Op, ReduceOp, TernaryOp, UOp, UOpKey, UnaryOp, WmmaMetadata};
23
24use crate::TypedPatternMatcher;
25use smallvec::SmallVec;
26
27#[derive(Debug, Default)]
32pub struct ReduceContext {
33 range_to_ends: HashMap<SmallVec<[u64; 4]>, Vec<Arc<UOp>>>,
34}
35
36impl ReduceContext {
37 pub fn register_end(&mut self, end: &Arc<UOp>) {
39 if let Op::End { ranges, .. } = end.op() {
40 let mut key: SmallVec<[u64; 4]> = ranges.iter().map(|r| r.id).collect();
41 key.sort_unstable();
42 self.range_to_ends.entry(key).or_default().push(end.clone());
43 }
44 }
45
46 pub fn merge_reduce_ends(&mut self, sources: &SmallVec<[Arc<UOp>; 4]>) -> Option<Arc<UOp>> {
54 #[allow(clippy::mutable_key_type)]
55 let subs = build_end_merge_subs(&self.range_to_ends);
56 self.range_to_ends.clear();
57 if subs.is_empty() {
58 return None;
59 }
60 Some(UOp::sink(sources.to_vec()).substitute(&subs))
61 }
62}
63
64#[allow(clippy::mutable_key_type)]
66fn build_end_merge_subs(range_to_ends: &HashMap<SmallVec<[u64; 4]>, Vec<Arc<UOp>>>) -> HashMap<UOpKey, Arc<UOp>> {
67 let mut subs = HashMap::new();
68 for ends in range_to_ends.values() {
69 if ends.len() <= 1 {
70 continue;
71 }
72 let computations: Vec<Arc<UOp>> = ends
73 .iter()
74 .map(|e| match e.op() {
75 Op::End { computation, .. } => computation.clone(),
76 _ => unreachable!(),
77 })
78 .collect();
79 let ranges = match ends[0].op() {
80 Op::End { ranges, .. } => ranges.clone(),
81 _ => unreachable!(),
82 };
83 let merged = UOp::group(computations).end(ranges);
84 for end in ends {
85 subs.insert(UOpKey(end.clone()), merged.clone());
86 }
87 }
88 subs
89}
90
91pub fn merge_sibling_ends(sink: &Arc<UOp>) -> Arc<UOp> {
101 let Op::Sink { sources } = sink.op() else { return sink.clone() };
102
103 let mut range_to_ends: HashMap<SmallVec<[u64; 4]>, Vec<Arc<UOp>>> = HashMap::new();
104 for node in sink.toposort() {
105 if let Op::End { ranges, .. } = node.op() {
106 let mut key: SmallVec<[u64; 4]> = ranges.iter().map(|r| r.id).collect();
107 key.sort_unstable();
108 range_to_ends.entry(key).or_default().push(node.clone());
109 }
110 }
111
112 #[allow(clippy::mutable_key_type)]
113 let subs = build_end_merge_subs(&range_to_ends);
114 if subs.is_empty() {
115 return sink.clone();
116 }
117 UOp::sink(sources.to_vec()).substitute(&subs)
118}
119
120use crate::rewrite::graph_rewrite;
121use crate::symbolic::patterns::sym;
122
123pub fn devectorize(ast: &Arc<UOp>) -> Arc<UOp> {
135 static COMBINED: LazyLock<TypedPatternMatcher> = LazyLock::new(|| {
136 sym()
137 + devectorize_patterns()
138 + load_store_folding_patterns()
139 + correct_load_store_patterns()
140 + load_store_indexing_patterns()
141 });
142 graph_rewrite(&*COMBINED, ast.clone(), &mut ())
143}
144
145pub fn bool_storage_patterns() -> &'static TypedPatternMatcher {
148 crate::cached_patterns! {
149 Store { index, value, ranges } if value.dtype().base().is_bool() => {
151 let uint8_dtype = value.dtype().with_base(ScalarDType::UInt8);
152 Some(index.store_with_ranges(value.cast(uint8_dtype), ranges.clone()))
153 },
154
155 load @ Load { buffer, index } if load.dtype().base().is_bool() => {
157 let uint8_dtype = load.dtype().with_base(ScalarDType::UInt8);
158 let uint8_load = UOp::load().buffer(buffer.clone()).index(index.clone()).dtype(uint8_dtype).call();
159 Some(uint8_load.cast(load.dtype()))
160 },
161
162 BitCast { src, dtype } if src.dtype().base().is_bool() || dtype.base().is_bool() => {
164 Some(src.cast(dtype.clone()))
165 },
166 }
167}
168
169#[derive(Debug, Clone)]
176pub struct Fp8DecompCtx {
177 pub from: ScalarDType,
178 pub to: ScalarDType,
179}
180
181fn rne(v: &Arc<UOp>, s: u32) -> Arc<UOp> {
184 let one = v.const_like(1);
185 let shifted = v.shr(&v.const_like(s));
186 let half_bit = v.shr(&v.const_like(s - 1)).and_(&one);
187 let remainder_mask = v.const_like((1i64 << (s - 1)) - 1);
188 let has_remainder = v.and_(&remainder_mask).ne(&v.const_like(0)).cast(v.dtype());
189 let lsb = shifted.and_(&one);
190 let round_up = half_bit.and_(&has_remainder.or_(&lsb));
191 shifted.try_add(&round_up).expect("rne: add failed")
192}
193
194fn f2f(v: &Arc<UOp>, fr: ScalarDType, to: ScalarDType) -> Arc<UOp> {
200 let (fe, fm) = fr.finfo();
201 let (te, tm) = to.finfo();
202 let fs = fr.bitsize();
203 let ts = to.bitsize();
204 let fb = fr.exponent_bias() as i64;
205 let tb = to.exponent_bias() as i64;
206 let fr_uint = DType::Scalar(fr.float_to_uint());
207 let to_uint = DType::Scalar(to.float_to_uint());
208
209 if fe <= te && fm < tm {
210 let sign_mask = v.const_like(1i64 << (fs - 1));
212 let sign = v.and_(&sign_mask).cast(to_uint.clone()).shl(&v.const_like(ts - fs).cast(to_uint.clone()));
213 let nosign_mask = v.const_like((1i64 << (fs - 1)) - 1);
214 let nosign = v.and_(&nosign_mask).cast(to_uint.clone());
215 let exp = nosign.shr(&nosign.const_like(fm));
216 let norm = nosign
217 .shl(&nosign.const_like(tm - fm))
218 .try_add(&nosign.const_like((tb - fb) << tm))
219 .expect("f2f: add failed");
220 let nan_val = nosign.shl(&nosign.const_like(tm - fm)).or_(&nosign.const_like(((1i64 << te) - 1) << tm));
221
222 let is_nan = if fr == ScalarDType::FP8E4M3 {
224 nosign.eq(&nosign.const_like((1i64 << (fm + fe)) - 1))
225 } else {
226 exp.eq(&exp.const_like((1i64 << fe) - 1))
227 };
228
229 let zero = nosign.const_like(0);
230 let exp_is_zero = exp.eq(&zero);
231 let inner = UOp::try_where(is_nan, nan_val, norm).expect("f2f: where failed");
232 let result = UOp::try_where(exp_is_zero, zero, inner).expect("f2f: where failed");
233 sign.or_(&result).bitcast(DType::Scalar(to))
234 } else if fe >= te && fm > tm {
235 let clamped = f2f_clamp(&v.bitcast(DType::Scalar(fr)), to);
237 let v = clamped.bitcast(fr_uint);
238 let sign = v.shr(&v.const_like(fs - ts)).and_(&v.const_like(1i64 << (ts - 1)));
239 let nosign_mask = v.const_like((1i64 << (fs - 1)) - 1);
240 let nosign = v.and_(&nosign_mask);
241 let norm = rne(&nosign, fm - tm)
242 .try_sub(&nosign.const_like((fb - tb) << tm))
243 .expect("f2f: sub failed")
244 .cast(to_uint.clone());
245
246 let exp_field = nosign.shr(&nosign.const_like(fm)).and_(&nosign.const_like((1i64 << fe) - 1));
247 let underflow = exp_field.lt(&exp_field.const_like(1 + fb - tb));
248
249 let nan_mantissa = if to == ScalarDType::FP8E4M3 {
250 sign.const_like((1i64 << tm) - 1).cast(to_uint.clone())
251 } else {
252 nosign.shr(&nosign.const_like(fm - tm)).and_(&nosign.const_like((1i64 << tm) - 1)).cast(to_uint.clone())
253 };
254 let nan_exp = sign.const_like(((1i64 << te) - 1) << tm).cast(to_uint.clone());
255 let nan = sign.cast(to_uint.clone()).or_(&nan_mantissa).or_(&nan_exp);
256
257 let is_nan = exp_field.eq(&exp_field.const_like((1i64 << fe) - 1));
258 let zero = sign.const_like(0).cast(to_uint.clone());
259 let normal = sign.cast(to_uint.clone()).or_(&UOp::try_where(underflow, zero, norm).expect("f2f: where failed"));
260 UOp::try_where(is_nan, nan, normal).expect("f2f: where failed")
261 } else {
262 panic!("f2f: unsupported conversion {fr:?} -> {to:?}")
263 }
264}
265
266fn f2f_clamp(val: &Arc<UOp>, dt: ScalarDType) -> Arc<UOp> {
269 let (e, m) = dt.finfo();
270 let (max_exp, max_man): (i64, i64) =
271 if dt == ScalarDType::FP8E4M3 { ((1 << e) - 1, (1 << m) - 2) } else { ((1 << e) - 2, (1 << m) - 1) };
272 let mx_f64 =
273 f64::powi(2.0, (max_exp - dt.exponent_bias() as i64) as i32) * (1.0 + max_man as f64 / (1i64 << m) as f64);
274 let mx = val.const_like(mx_f64);
275 let neg_mx = val.const_like(-mx_f64);
276
277 let sat = if dt.is_fp8() { mx.clone() } else { val.const_like(f64::INFINITY) };
279 let neg_sat = if dt.is_fp8() { neg_mx.clone() } else { val.const_like(f64::NEG_INFINITY) };
280
281 let is_nan = val.ne(val);
283 let below = val.lt(&neg_mx);
284 let above = mx.lt(val);
285 let clamped_above = UOp::try_where(above, sat, val.clone()).expect("f2f_clamp: where failed");
286 let clamped = UOp::try_where(below, neg_sat, clamped_above).expect("f2f_clamp: where failed");
287 UOp::try_where(is_nan, val.clone(), clamped).expect("f2f_clamp: where failed")
288}
289
290pub fn pm_float_decomp_store() -> crate::TypedPatternMatcher<Fp8DecompCtx> {
296 crate::patterns! {
297 @context Fp8DecompCtx;
298
299 Store { index, value, ranges }
302 if index.dtype().base() == ctx.from
303 => {
304 let target_float = DType::Scalar(ctx.to);
305 let target_uint = DType::Scalar(ctx.to.float_to_uint());
306 let float_val = value.cast(target_float);
308 let result = f2f(&float_val.bitcast(target_uint), ctx.to, ctx.from);
310 let uint8_ptr = index.dtype().with_ptr_base(DType::Scalar(ctx.from.float_to_uint()))?;
312 let new_index = index.with_dtype(uint8_ptr);
313 Some(new_index.store_with_ranges(result, ranges.clone()))
314 },
315 }
316}
317
318pub fn pm_float_decomp() -> crate::TypedPatternMatcher<Fp8DecompCtx> {
323 crate::patterns! {
324 @context Fp8DecompCtx;
325
326 x if matches!(x.op(), Op::Param { device: None, .. } | Op::DefineLocal(_) | Op::Index { .. })
328 && x.dtype().base() == ctx.from
329 => {
330 let uint8_ptr = x.dtype().with_ptr_base(DType::Scalar(ctx.from.float_to_uint()))?;
331 Some(x.with_dtype(uint8_ptr))
332 },
333
334 load @ Load { buffer, index } if load.dtype().base() == ctx.from => {
336 let uint_dtype = DType::Scalar(ctx.from.float_to_uint());
337 let uint_load = UOp::load().buffer(buffer.clone()).index(index.clone())
338 .dtype(uint_dtype).call();
339 Some(f2f(&uint_load, ctx.from, ctx.to))
340 },
341
342 x @ Cast { src: val, .. } if x.dtype().base() == ctx.from => {
346 let target = DType::Scalar(ctx.to);
347 let target_uint = DType::Scalar(ctx.to.float_to_uint());
348 let float_val = val.cast(target);
349 let fp8_bytes = f2f(&float_val.bitcast(target_uint.clone()), ctx.to, ctx.from);
351 Some(f2f(&fp8_bytes, ctx.from, ctx.to))
353 },
354
355 x if !matches!(x.op(), Op::BitCast { .. })
357 && x.dtype().is_float()
358 && x.dtype().base() == ctx.from
359 => {
360 let target_dtype = DType::Scalar(ctx.to);
361 let new_dtype = if x.dtype().vcount() > 1 {
362 target_dtype.vec(x.dtype().vcount())
363 } else {
364 target_dtype.clone()
365 };
366 let new_sources: Vec<Arc<UOp>> = x.op().sources().iter().map(|s| {
367 if s.dtype().base() == ctx.from {
368 s.cast(target_dtype.clone())
369 } else {
370 s.clone()
371 }
372 }).collect();
373 Some(x.with_sources(new_sources).with_dtype(new_dtype))
374 },
375 }
376}
377
378pub fn pm_render() -> &'static TypedPatternMatcher {
381 crate::cached_patterns! {
382 c @ Const(_) if c.dtype().vcount() > 1 => |c| {
384 let vcount = c.dtype().vcount();
385 let Op::Const(cv) = c.op() else { return None };
386 let scalar_const = UOp::const_(c.dtype().scalar_dtype(), cv.0);
387 let elements: SmallVec<[Arc<UOp>; 4]> = (0..vcount).map(|_| scalar_const.clone()).collect();
388 Some(UOp::vectorize(elements))
389 },
390
391 vc @ VConst { values } => |vc, values| {
393 let scalar_dtype = vc.dtype().scalar_dtype();
394 let elements: SmallVec<[Arc<UOp>; 4]> = values.iter()
395 .map(|v| UOp::const_(scalar_dtype.clone(), *v))
396 .collect();
397 Some(UOp::vectorize(elements))
398 },
399
400 Cat { sources } if sources.len() == 1 => Some(sources[0].clone()),
402 Cat { sources } => {
403 let elements: SmallVec<[Arc<UOp>; 4]> = sources.iter()
404 .flat_map(|src| {
405 let n = src.dtype().vcount();
406 (0..n).map(move |i| if n == 1 { src.clone() } else { src.gep(vec![i]) })
407 })
408 .collect();
409 Some(UOp::vectorize(elements))
410 },
411
412 Gep { vector, indices } if vector.dtype().vcount() == 1 && indices.len() == 1 && indices[0] == 0
414 ~> |vector| Arc::clone(vector),
415
416 Gep { vector, indices } if is_identity_gep(vector, indices) => Some(vector.clone()),
418
419 Gep { vector: Vectorize { elements }, indices } => {
421 let extracted: SmallVec<[Arc<UOp>; 4]> = indices.iter()
422 .filter_map(|&i| elements.get(i).cloned())
423 .collect();
424 if extracted.len() != indices.len() { return None; }
425 Some(if extracted.len() == 1 { extracted[0].clone() } else { UOp::vectorize(extracted) })
426 },
427
428 Gep { vector, indices } if indices.len() > 1 => |vector, indices| {
430 let geps: SmallVec<[Arc<UOp>; 4]> = indices.iter()
431 .map(|&i| vector.gep(vec![i]))
432 .collect();
433 Some(UOp::vectorize(geps))
434 },
435
436 Vectorize { elements } if elements.len() == 1 => Some(elements[0].clone()),
438 PtrCat { sources } if sources.len() == 1 => Some(sources[0].clone()),
439
440 load @ Load { index, alt: None, .. } if has_gate(index) => |load, index| {
447 let alt_value = load.const_like(ConstValue::Int(0));
448 Some(UOp::load().buffer(load.load_buffer()?).index(index.clone()).alt(alt_value).dtype(load.dtype()).call())
449 },
450
451 Where(cond, load @ Load { index, .. }, alt)
457 if index_has_gate_matching(index, cond)
458 => |cond, load, index, alt| {
459 let casted_alt = cast_alt_avoiding_roundtrip(alt, &load.dtype());
460 let new_load = UOp::load()
461 .buffer(load.load_buffer()?)
462 .index(index.clone())
463 .alt(casted_alt)
464 .dtype(load.dtype())
465 .call();
466 Some(new_load.cast(alt.dtype()))
467 },
468
469 Where(cond, alt, load @ Load { index, .. })
473 if index_has_inverted_gate_matching(index, cond)
474 => |cond, alt, load, index| {
475 let casted_alt = cast_alt_avoiding_roundtrip(alt, &load.dtype());
476 let new_load = UOp::load()
477 .buffer(load.load_buffer()?)
478 .index(index.clone())
479 .alt(casted_alt)
480 .dtype(load.dtype())
481 .call();
482 Some(new_load.cast(alt.dtype()))
483 },
484 }
485}
486
487fn cast_alt_avoiding_roundtrip(alt: &Arc<UOp>, load_dtype: &DType) -> Arc<UOp> {
491 if let Op::Cast { src: inner, .. } = alt.op()
492 && inner.dtype() == *load_dtype
493 {
494 return inner.clone();
495 }
496 alt.cast(load_dtype.clone())
497}
498
499fn is_identity_gep(vector: &Arc<UOp>, indices: &[usize]) -> bool {
501 let vcount = vector.dtype().vcount();
502 indices.len() == vcount && indices.iter().enumerate().all(|(i, &j)| i == j)
503}
504
505fn has_gate(index: &Arc<UOp>) -> bool {
507 match index.op() {
508 Op::Index { gate: Some(_), .. } => true,
509 Op::Cast { src, .. } => has_gate(src),
510 _ => false,
511 }
512}
513
514fn index_has_gate_matching(index: &Arc<UOp>, cond: &Arc<UOp>) -> bool {
516 match index.op() {
517 Op::Index { gate: Some(g), .. } => Arc::ptr_eq(g, cond),
518 Op::Cast { src, .. } => index_has_gate_matching(src, cond),
519 _ => false,
520 }
521}
522
523fn index_has_inverted_gate_matching(index: &Arc<UOp>, cond: &Arc<UOp>) -> bool {
533 match index.op() {
534 Op::Index { gate: Some(g), .. } => is_negation_of(g, cond),
535 Op::Cast { src, .. } => index_has_inverted_gate_matching(src, cond),
536 _ => false,
537 }
538}
539
540fn is_negation_of(gate: &Arc<UOp>, cond: &Arc<UOp>) -> bool {
542 if let Op::Unary(UnaryOp::Not, inner) = gate.op()
544 && Arc::ptr_eq(inner, cond)
545 {
546 return true;
547 }
548
549 if let Op::Binary(BinaryOp::Lt, gate_lhs, gate_rhs) = gate.op()
552 && let Op::Binary(BinaryOp::Lt, cond_lhs, cond_rhs) = cond.op()
553 {
554 if Arc::ptr_eq(gate_rhs, cond_lhs)
556 && let (Op::Const(gate_cv), Op::Const(cond_cv)) = (gate_lhs.op(), cond_rhs.op())
557 && is_const_minus_one(&cond_cv.0, &gate_cv.0)
558 {
559 return true;
560 }
561 if Arc::ptr_eq(gate_lhs, cond_rhs)
563 && let (Op::Const(cond_cv), Op::Const(gate_cv)) = (cond_lhs.op(), gate_rhs.op())
564 && is_const_minus_one(&cond_cv.0, &gate_cv.0)
565 {
566 return true;
567 }
568 }
569
570 false
571}
572
573fn is_const_minus_one(a: &ConstValue, b: &ConstValue) -> bool {
575 match (a, b) {
576 (ConstValue::Int(av), ConstValue::Int(bv)) => av.checked_sub(1) == Some(*bv),
577 (ConstValue::UInt(av), ConstValue::UInt(bv)) => av.checked_sub(1) == Some(*bv),
578 _ => false,
579 }
580}
581
582fn devectorize_alu(alu: &Arc<UOp>) -> Option<Arc<UOp>> {
595 let vcount = alu.dtype().vcount();
596 if vcount <= 1 {
597 return None;
598 }
599
600 if let Op::Ternary(TernaryOp::Where, _, _, f) = alu.op()
603 && UOp::is_invalid_marker(f)
604 {
605 return None;
606 }
607
608 let scalar_dtype = alu.dtype().scalar_dtype();
609 let sources = alu.op().sources();
610
611 let elements: SmallVec<[Arc<UOp>; 4]> = (0..vcount)
612 .map(|i| {
613 let new_sources: Vec<Arc<UOp>> =
615 sources.iter().map(|s| if s.dtype().vcount() > 1 { s.gep(vec![i]) } else { s.clone() }).collect();
616
617 match alu.op() {
621 Op::Cast { .. } => new_sources[0].cast(scalar_dtype.clone()),
622 Op::BitCast { .. } => new_sources[0].bitcast(scalar_dtype.clone()),
623 _ => alu.replace().dtype(scalar_dtype.clone()).src(new_sources).call(),
624 }
625 })
626 .collect();
627
628 Some(UOp::vectorize(elements))
629}
630
/// Pattern set that hands every multi-lane ALU-style op to `devectorize_alu`.
///
/// Covers all binary, unary and ternary ops as well as `Cast` and `BitCast`;
/// each arm only fires when the node's dtype has more than one lane.
#[allow(unused_variables)]
pub fn no_vectorized_alu() -> &'static TypedPatternMatcher {
    crate::cached_patterns! {
        // Any binary op with a multi-lane result.
        for op in binary [*] {
            alu @ op(_, _) if alu.dtype().vcount() > 1 => devectorize_alu(alu),
        },
        // Any unary op with a multi-lane result.
        for op in unary [*] {
            alu @ op(_) if alu.dtype().vcount() > 1 => devectorize_alu(alu),
        },
        // Any ternary op with a multi-lane result.
        for op in ternary [*] {
            alu @ op(_, _, _) if alu.dtype().vcount() > 1 => devectorize_alu(alu),
        },
        // Casts and bitcasts are not part of the op groups above, so they get their own arms.
        alu @ Cast { src: _, .. } if alu.dtype().vcount() > 1 => devectorize_alu(alu),
        alu @ BitCast { src: _, .. } if alu.dtype().vcount() > 1 => devectorize_alu(alu),
    }
}
653
654pub fn devectorize_patterns() -> &'static TypedPatternMatcher {
660 use std::sync::LazyLock;
661 static CACHED: LazyLock<TypedPatternMatcher> = LazyLock::new(|| {
662 cast_after_pattern() + no_vectorized_alu() + no_vectorized_wmma() + devectorize_buf_and_index_patterns()
663 });
664 &CACHED
665}
666
/// Pattern set that splits a `Wmma` whose result has more lanes than
/// `wmma_expected_size` reports for its metadata; the split itself is done by
/// `devectorize_wmma`.
fn no_vectorized_wmma() -> &'static TypedPatternMatcher {
    crate::cached_patterns! {
        // Only oversized WMMAs match; one that already has the expected size is left alone.
        wmma @ Wmma { a, b, c, metadata } if wmma.dtype().vcount() > wmma_expected_size(metadata)
            => devectorize_wmma(wmma, a, b, c, metadata),
    }
}
674
675fn wmma_expected_size(metadata: &WmmaMetadata) -> usize {
676 metadata.upcast_axes.c.iter().map(|(_, size)| size).product::<usize>().max(1)
677}
678
679fn devectorize_wmma(
680 wmma: &Arc<UOp>,
681 a: &Arc<UOp>,
682 b: &Arc<UOp>,
683 c: &Arc<UOp>,
684 metadata: &WmmaMetadata,
685) -> Option<Arc<UOp>> {
686 let out_sz = wmma_expected_size(metadata);
687 if wmma.dtype().vcount() == out_sz {
688 return None;
689 }
690
691 let sources: [&Arc<UOp>; 3] = [a, b, c];
695 let tsrcs: Vec<Vec<Arc<UOp>>> = sources
696 .iter()
697 .enumerate()
698 .map(|(i, src)| {
699 let ssz = metadata.upcast_axes.source_size(i);
700 let n = src.dtype().vcount();
701 (0..n).step_by(ssz).map(|g| src.gep((g..g + ssz.min(n - g)).collect())).collect()
702 })
703 .collect();
704
705 let num_groups = tsrcs[0].len();
707 if tsrcs.iter().any(|t| t.len() != num_groups) {
708 tracing::warn!("WMMA devectorization: mismatched source group counts");
709 return None;
710 }
711
712 let wmma_ex: SmallVec<[Arc<UOp>; 4]> = (0..num_groups)
714 .flat_map(|g| {
715 let w = UOp::wmma(tsrcs[0][g].clone(), tsrcs[1][g].clone(), tsrcs[2][g].clone(), metadata.clone());
716 (0..out_sz).map(move |i| w.gep(vec![i]))
717 })
718 .collect();
719
720 Some(UOp::vectorize(wmma_ex))
721}
722
723fn cast_after_pattern() -> &'static TypedPatternMatcher {
725 crate::cached_patterns! {
726 After { passthrough: Cast { src, dtype }, deps }
727 => |src, dtype, deps| {
728 let new_after = src.after(deps.clone());
729 Some(new_after.cast(dtype.clone()))
730 },
731 }
732}
733
734fn devectorize_buf_and_index_patterns() -> &'static TypedPatternMatcher {
741 crate::cached_patterns! {
742 def if matches!(def.op(), Op::DefineLocal(_) | Op::DefineReg { .. })
744 && def.ptrdtype().is_some_and(|(base, _, _)| base.vcount() > 1)
745 => no_vectorized_buf(def),
746
747 Index { buffer: Cast { src: buf, dtype: cast_dtype }, indices, gate }
750 if is_vectorized_local_reg_cast(buf, cast_dtype)
751 => no_vectorized_index(buf, indices, gate, cast_dtype),
752
753 Index { buffer: Vectorize { elements }, indices, gate }
755 if is_vectorized_broadcast_cast(elements)
756 => {
757 let first = elements.first()?;
758 let Op::Cast { src: buf, dtype: DType::Ptr { base, .. } } = first.op() else { return None };
759 let idx = indices.first()?;
760 no_vectorized_index_precnt(buf, idx, gate, base.vcount(), &vec![0; elements.len()])
761 },
762
763 Index { buffer: Gep { vector: Cast { src: buf, dtype: cast_dtype }, indices: gep_indices }, indices, gate }
765 if is_vectorized_local_reg_cast(buf, cast_dtype)
766 => {
767 let DType::Ptr { base, .. } = cast_dtype else { return None };
768 let idx = indices.first()?;
769 no_vectorized_index_precnt(buf, idx, gate, base.vcount(), gep_indices)
770 },
771 }
772}
773
774fn is_vectorized_local_reg_cast(buf: &Arc<UOp>, cast_dtype: &DType) -> bool {
775 matches!(cast_dtype, DType::Ptr { base, .. } if base.vcount() > 1) && is_define_local_or_reg_or_after(buf)
776}
777
778fn is_vectorized_broadcast_cast(elements: &SmallVec<[Arc<UOp>; 4]>) -> bool {
779 elements.first().is_some_and(|f| {
780 matches!(f.op(), Op::Cast { dtype: DType::Ptr { base, .. }, src }
781 if base.vcount() > 1 && is_define_local_or_reg_or_after(src))
782 })
783}
784
785fn is_define_local_or_reg_or_after(uop: &Arc<UOp>) -> bool {
787 matches!(uop.unwrap_after().op(), Op::DefineLocal(_) | Op::DefineReg { .. })
788}
789
790fn no_vectorized_buf(buf: &Arc<UOp>) -> Option<Arc<UOp>> {
792 let (base, addrspace, size) = buf.ptrdtype()?;
793 let vcount = base.vcount();
794 if vcount <= 1 {
795 return None;
796 }
797
798 let scalar_base = base.base();
799 let new_size = size.map(|s| s * vcount);
800 let scalar_ptr_dtype =
801 DType::Ptr { base: Box::new(DType::Scalar(scalar_base)), addrspace, size: new_size, vcount: 1 };
802
803 let scalar_def = buf.with_dtype(scalar_ptr_dtype);
804 Some(scalar_def.cast(buf.dtype()))
805}
806
807fn no_vectorized_index(
813 buf: &Arc<UOp>,
814 indices: &SmallVec<[Arc<UOp>; 4]>,
815 gate: &Option<Arc<UOp>>,
816 cast_dtype: &DType,
817) -> Option<Arc<UOp>> {
818 let idx = indices.first()?;
819 let DType::Ptr { base, .. } = cast_dtype else { return None };
820 let cnt = base.vcount();
821 if cnt <= 1 {
822 return None;
823 }
824
825 let idx_vcount = idx.dtype().vcount();
826 let total = cnt * idx_vcount;
827 let buf_broadcast = buf.broadcast(total);
828
829 let final_idx = if idx_vcount == 1 {
830 let idx_broadcast = idx.broadcast(cnt);
832 let cnt_broadcast = idx.const_like(cnt as i64).broadcast(cnt);
833 idx_broadcast.mul(&cnt_broadcast).add(&create_index_vector(0..cnt as i64))
834 } else {
835 let elements: SmallVec<[Arc<UOp>; 4]> = (0..idx_vcount)
838 .flat_map(|i| {
839 let lane = idx.gep(vec![i]);
840 let cnt_const = UOp::const_(lane.dtype(), ConstValue::Int(cnt as i64));
841 let scaled = lane.mul(&cnt_const);
842 (0..cnt).map(move |j| scaled.add(&UOp::const_(scaled.dtype(), ConstValue::Int(j as i64))))
843 })
844 .collect();
845 UOp::vectorize(elements)
846 };
847
848 let expanded_gate = if idx_vcount > 1 {
850 gate.as_ref().map(|g| {
851 if g.dtype().vcount() > 1 {
852 let elements: SmallVec<[Arc<UOp>; 4]> = (0..idx_vcount)
853 .flat_map(|i| {
854 let lane = g.gep(vec![i]);
855 std::iter::repeat_n(lane, cnt)
856 })
857 .collect();
858 UOp::vectorize(elements)
859 } else {
860 g.broadcast(total)
861 }
862 })
863 } else {
864 gate.clone()
865 };
866
867 Some(
868 UOp::index()
869 .buffer(buf_broadcast)
870 .indices(vec![final_idx])
871 .maybe_gate(expanded_gate)
872 .ptr(true)
873 .call()
874 .expect("ICE unable to create index"),
875 )
876}
877
878fn create_index_vector(values: impl IntoIterator<Item = i64>) -> Arc<UOp> {
879 let elements: SmallVec<[Arc<UOp>; 4]> = values.into_iter().map(UOp::index_const).collect();
880 UOp::vectorize(elements)
881}
882
883fn no_vectorized_index_precnt(
888 buf: &Arc<UOp>,
889 idx: &Arc<UOp>,
890 gate: &Option<Arc<UOp>>,
891 cnt: usize,
892 input_gep: &[usize],
893) -> Option<Arc<UOp>> {
894 let precnt = input_gep.len();
895 let idx_vcount = idx.dtype().vcount();
896
897 if idx_vcount == 1 {
898 let total = cnt * precnt;
900 let gep_arg: Vec<usize> = (0..cnt).flat_map(|_| 0..precnt).collect();
901 let sum_arg = (0..cnt).flat_map(|i| input_gep.iter().map(move |&y| (i + y) as i64));
902
903 let buf_broadcast = buf.broadcast(total);
904 let final_idx =
905 idx.gep(gep_arg).mul(&idx.const_like(cnt as i64).broadcast(total)).add(&create_index_vector(sum_arg));
906
907 Some(
908 UOp::index()
909 .buffer(buf_broadcast)
910 .indices(vec![final_idx])
911 .maybe_gate(gate.clone())
912 .ptr(true)
913 .call()
914 .expect("ICE: unable to create index"),
915 )
916 } else {
917 let per_lane = cnt * precnt;
919 let total = per_lane * idx_vcount;
920
921 let buf_broadcast = buf.broadcast(total);
922 let elements: SmallVec<[Arc<UOp>; 4]> = (0..idx_vcount)
923 .flat_map(|i| {
924 let lane = idx.gep(vec![i]);
925 let cnt_const = UOp::const_(lane.dtype(), ConstValue::Int(cnt as i64));
926 let scaled = lane.mul(&cnt_const);
927 (0..cnt).flat_map(move |c| {
928 let s = scaled.clone();
929 input_gep.iter().map(move |&y| s.add(&UOp::const_(s.dtype(), ConstValue::Int((c + y) as i64))))
930 })
931 })
932 .collect();
933 let final_idx = UOp::vectorize(elements);
934
935 let expanded_gate = gate.as_ref().map(|g| {
936 if g.dtype().vcount() > 1 {
937 let elements: SmallVec<[Arc<UOp>; 4]> = (0..idx_vcount)
938 .flat_map(|i| {
939 let lane = g.gep(vec![i]);
940 std::iter::repeat_n(lane, per_lane)
941 })
942 .collect();
943 UOp::vectorize(elements)
944 } else {
945 g.broadcast(total)
946 }
947 });
948
949 Some(
950 UOp::index()
951 .buffer(buf_broadcast)
952 .indices(vec![final_idx])
953 .maybe_gate(expanded_gate)
954 .ptr(true)
955 .call()
956 .expect("ICE: unable to create index"),
957 )
958 }
959}
960
/// Pattern set that drops trivially-true gates from `Index` nodes.
///
/// An `Index` whose gate is the boolean constant `true` is replaced by the
/// same `Index` (same buffer, indices and dtype) with no gate at all.
pub fn load_store_indexing_patterns() -> &'static TypedPatternMatcher {
    crate::cached_patterns! {
        // Only a literal `Const(Bool(true))` gate qualifies; any other gate is kept.
        index @ Index { buffer, indices, gate: Some(g) }
            if matches!(g.op(), Op::Const(cv) if matches!(cv.0, ConstValue::Bool(true)))
            ~> UOp::new(Op::Index { buffer: buffer.clone(), indices: indices.clone(), gate: None }, index.dtype())
    }
}
985
986pub fn pm_add_loads() -> &'static TypedPatternMatcher {
992 crate::cached_patterns! {
993 idx @ Index { buffer, .. } if !is_ptr_or_image_dtype(&idx.dtype()) => {
996 let new_idx = idx.with_dtype(buffer.dtype());
997 Some(UOp::load().buffer(buffer.clone()).index(new_idx).dtype(idx.dtype().scalar_dtype()).call())
998 },
999
1000 Store { index: Load { index: inner_idx, .. }, value, ranges }
1003 => Some(inner_idx.store_with_ranges(value.clone(), ranges.clone())),
1004 }
1005}
1006
1007fn is_ptr_or_image_dtype(dtype: &DType) -> bool {
1008 matches!(dtype, DType::Ptr { .. } | DType::Image { .. })
1009}
1010
1011pub fn pm_wmma_accumulate() -> &'static TypedPatternMatcher {
1018 crate::cached_patterns! {
1019 Add(wmma @ Wmma { a, b, c, metadata }, add) => |wmma, a, b, c, metadata, add| {
1022 if wmma.dtype() != add.dtype() {
1024 return None;
1025 }
1026 let new_c = c.add(add);
1027 Some(UOp::wmma(a.clone(), b.clone(), new_c, metadata.clone()))
1028 },
1029
1030 Add(add, wmma @ Wmma { a, b, c, metadata }) => |wmma, add, a, b, c, metadata| {
1032 if wmma.dtype() != add.dtype() {
1033 return None;
1034 }
1035 let new_c = c.add(add);
1036 Some(UOp::wmma(a.clone(), b.clone(), new_c, metadata.clone()))
1037 },
1038 }
1039}
1040
/// Pattern set that restructures vector indexing around loads and stores.
///
/// Each arm only selects nodes and delegates the rewrite to a helper.
pub fn load_store_folding_patterns() -> &'static TypedPatternMatcher {
    crate::cached_patterns! {
        // A vector-indexed `Index` over a `Vectorize` of buffers: one scalar `Index` per lane.
        index if is_vector_index(index) => expand_index_to_vectorize(index),

        // A `Vectorize` made up entirely of `Index` nodes.
        midx @ Vectorize { elements } if elements.iter().all(|e| matches!(e.op(), Op::Index { .. }))
            => fold_expanded_index(midx),

        // A load through a lane selection: load first, select lanes afterwards.
        load @ Load { buffer, index: Gep { vector, indices } }
            => move_gep_after_load(load, buffer, vector, indices),

        // A store through a lane selection: reorder the stored value instead.
        Store { index: Gep { vector, indices }, value, ranges }
            => move_gep_on_store(vector, indices, value, ranges),

        // A load through a `PtrCat`.
        load @ Load { buffer, index: ptrcat @ PtrCat { sources } }
            => distribute_ptrcat_load(load, buffer, ptrcat, sources),

        // A store through a `PtrCat`.
        Store { index: PtrCat { sources }, value, ranges }
            => distribute_ptrcat_store(sources, value, ranges),
    }
}
1072
/// Pattern set that post-processes loads and stores.
///
/// Loads and stores whose index is a cast of an `Index` go to
/// `split_load_store`; every load and store is also offered to `image_fixup`.
pub fn correct_load_store_patterns() -> &'static TypedPatternMatcher {
    crate::cached_patterns! {
        // Load through `Cast(Index(..))`.
        ls @ Load { index: Cast { src: idx @ Index { buffer: _, .. }, .. }, .. }
            => split_load_store(ls, idx),

        // Store through `Cast(Index(..))`.
        ls @ Store { index: Cast { src: idx @ Index { buffer: _, .. }, .. }, .. }
            => split_load_store(ls, idx),

        // Catch-all arms for any load or store.
        ls @ Load { buffer: _, index: _, alt: _ } => image_fixup(ls),
        ls @ Store { index: _, value: _, ranges: _ } => image_fixup(ls),
    }
}
1092
1093fn is_define_or_after(uop: &Arc<UOp>) -> bool {
1098 matches!(uop.unwrap_after().op(), Op::DefineLocal(_) | Op::DefineReg { .. } | Op::Param { device: None, .. })
1099}
1100
1101fn is_vector_index(uop: &Arc<UOp>) -> bool {
1104 let Op::Index { buffer, indices, .. } = uop.op() else { return false };
1105 let Some(idx) = indices.first() else { return false };
1106 if idx.dtype().vcount() <= 1 {
1107 return false;
1108 }
1109 let Op::Vectorize { elements } = buffer.op() else { return false };
1110 !elements.is_empty() && elements.iter().all(is_define_or_after)
1111}
1112
1113fn move_gep_after_load(
1125 load: &Arc<UOp>,
1126 buffer: &Arc<UOp>,
1127 gep_inner: &Arc<UOp>,
1128 gep_indices: &[usize],
1129) -> Option<Arc<UOp>> {
1130 let new_dtype = load.dtype().scalar_dtype().vec(gep_indices.len());
1131 let inner_load = load.replace().dtype(new_dtype).src(vec![buffer.clone(), gep_inner.clone()]).call();
1132 Some(inner_load.gep(gep_indices.to_vec()))
1133}
1134
1135fn move_gep_on_store(
1137 gep_inner: &Arc<UOp>,
1138 gep_indices: &[usize],
1139 value: &Arc<UOp>,
1140 ranges: &SmallVec<[Arc<UOp>; 4]>,
1141) -> Option<Arc<UOp>> {
1142 let mut inverse_map: Vec<(usize, usize)> = gep_indices.iter().enumerate().map(|(i, &x)| (x, i)).collect();
1144 inverse_map.sort_by_key(|&(x, _)| x);
1145 let inverse_indices: Vec<usize> = inverse_map.iter().map(|&(_, i)| i).collect();
1146
1147 let reordered_value = value.gep(inverse_indices);
1148 Some(gep_inner.store_with_ranges(reordered_value, ranges.clone()))
1149}
1150
1151fn expand_index_to_vectorize(index: &Arc<UOp>) -> Option<Arc<UOp>> {
1160 let Op::Index { buffer, indices, gate } = index.op() else { return None };
1161 assert!(indices.len() <= 1, "ICE: expand_index_to_vectorize called with multi-index INDEX (len={})", indices.len());
1162 let vec = indices.first()?;
1163 let count = vec.dtype().vcount();
1164
1165 let buf = if let Op::Vectorize { elements } = buffer.op() { elements.first()?.clone() } else { buffer.clone() };
1166
1167 let scalar_indices: Vec<_> = (0..count)
1168 .map(|i| {
1169 let lane_gate = gate.as_ref().map(|g| if g.dtype().vcount() > 1 { g.gep(vec![i]) } else { g.clone() });
1170 UOp::index()
1171 .buffer(buf.clone())
1172 .indices(vec![vec.gep(vec![i])])
1173 .maybe_gate(lane_gate)
1174 .ptr(true)
1175 .call()
1176 .expect("ICE: unable to create index")
1177 })
1178 .collect();
1179
1180 Some(UOp::vectorize(scalar_indices.into()))
1181}
1182
1183fn fold_expanded_index(midx: &Arc<UOp>) -> Option<Arc<UOp>> {
1188 let Op::Vectorize { elements: sources } = midx.op() else { return None };
1189 let count = sources.len();
1190 if count == 0 {
1191 return None;
1192 }
1193
1194 let first_buf = match sources[0].op() {
1196 Op::Index { buffer, .. } => buffer,
1197 _ => return None,
1198 };
1199 if !sources.iter().all(|s| matches!(s.op(), Op::Index { buffer, .. } if Arc::ptr_eq(buffer, first_buf))) {
1200 return None;
1201 }
1202 let buf = first_buf;
1203
1204 struct LaneData {
1206 valid: Arc<UOp>,
1207 root: Arc<UOp>,
1208 offset: i64,
1209 gate_id: u64,
1210 }
1211 let mut lane_data: Vec<(usize, LaneData)> = Vec::with_capacity(count);
1212
1213 for (lane, idx_op) in sources.iter().enumerate() {
1214 let Op::Index { indices: simp_indices, gate: lane_gate, .. } = idx_op.op() else { continue };
1215 let idx = simp_indices.first()?.get_idx();
1216 let valid = simp_indices.first()?.get_valid();
1217 let gate_id = lane_gate.as_ref().map_or(u64::MAX, |g| g.id);
1218
1219 let (root, offset) = match idx.op() {
1220 Op::Invalid => (UOp::invalid_marker(), 0),
1221 Op::Binary(BinaryOp::Add, l, r) if matches!(r.op(), Op::Const(cv) if matches!(cv.0, ConstValue::Int(_))) => {
1222 let Op::Const(cv) = r.op() else { unreachable!() };
1223 let ConstValue::Int(off) = cv.0 else { unreachable!() };
1224 (l.clone(), off)
1225 }
1226 Op::Binary(BinaryOp::Add, l, r) if matches!(l.op(), Op::Const(cv) if matches!(cv.0, ConstValue::Int(_))) => {
1227 let Op::Const(cv) = l.op() else { unreachable!() };
1228 let ConstValue::Int(off) = cv.0 else { unreachable!() };
1229 (r.clone(), off)
1230 }
1231 Op::Const(cv) if matches!(cv.0, ConstValue::Int(_)) => {
1232 let ConstValue::Int(off) = cv.0 else { unreachable!() };
1233 (UOp::index_const(0), off)
1234 }
1235 _ => (idx.clone(), 0),
1236 };
1237
1238 lane_data.push((lane, LaneData { valid, root, offset, gate_id }));
1239 }
1240
1241 let mut offsets_by_root: HashMap<(u64, u64, u64), HashMap<i64, Vec<usize>>> = HashMap::new();
1243 for (lane, data) in &lane_data {
1244 let key = (data.valid.id, data.root.id, data.gate_id);
1245 offsets_by_root.entry(key).or_default().entry(data.offset).or_default().push(*lane);
1246 }
1247
1248 let mut ret = Vec::new();
1250 let mut idxs: Vec<Option<usize>> = vec![None; count];
1251 let mut global_offset = 0;
1252
1253 for offsets in offsets_by_root.values() {
1254 let groups = group_consecutive_offsets_from_map(offsets);
1255 for grp in groups {
1256 let lidx = sources[offsets[&grp[0]][0]].clone();
1257 let ptr = if grp.len() > 1 { lidx.cast(make_vec_ptr_dtype(buf, grp.len())) } else { lidx };
1258 for (i, &offset) in grp.iter().enumerate() {
1259 for &lane in &offsets[&offset] {
1260 idxs[lane] = Some(global_offset + i);
1261 }
1262 }
1263 ret.push(ptr);
1264 global_offset += grp.len();
1265 }
1266 }
1267
1268 if idxs.iter().any(|x| x.is_none()) {
1269 return None;
1270 }
1271
1272 let DType::Ptr { base, addrspace, size, .. } = buf.dtype().clone() else { return None };
1273 let scalar_ptr = DType::Ptr { base: Box::new(DType::Scalar(base.scalar()?)), addrspace, size, vcount: 1 };
1274 let ptrcat_dtype = scalar_ptr.vec(global_offset);
1275 let ptrcat = UOp::ptrcat().sources(ret).dtype(ptrcat_dtype).call();
1276 let gep_indices: Vec<usize> = idxs.into_iter().map(|x| x.unwrap()).collect();
1277
1278 Some(ptrcat.gep(gep_indices))
1279}
1280
/// Splits the keys of `offsets_map` into runs of consecutive integers.
///
/// The keys are sorted ascending; each returned inner vector is a maximal run in which every
/// element is exactly one greater than its predecessor. Runs are returned in ascending order.
/// An empty map yields an empty vector.
fn group_consecutive_offsets_from_map(offsets_map: &HashMap<i64, Vec<usize>>) -> Vec<Vec<i64>> {
    let mut ordered: Vec<i64> = offsets_map.keys().copied().collect();
    ordered.sort_unstable();

    let mut runs: Vec<Vec<i64>> = Vec::new();
    let mut previous: Option<i64> = None;
    for offset in ordered {
        // `checked_add` keeps the adjacency test overflow-free even for `i64::MAX`.
        let extends_run = previous.and_then(|prev| prev.checked_add(1)) == Some(offset);
        if extends_run {
            if let Some(run) = runs.last_mut() {
                run.push(offset);
            }
        } else {
            runs.push(vec![offset]);
        }
        previous = Some(offset);
    }
    runs
}
1301
1302fn make_vec_ptr_dtype(buffer: &Arc<UOp>, vec_len: usize) -> DType {
1303 let (base_dtype, addrspace) = buffer
1304 .ptrdtype()
1305 .map(|(base, addrspace, _)| (base.base(), addrspace))
1306 .unwrap_or_else(|| (buffer.dtype().base(), AddrSpace::Global));
1307 let vec_dtype = DType::Vector { scalar: base_dtype, count: vec_len };
1308 DType::Ptr { base: Box::new(vec_dtype), addrspace, size: Some(vec_len), vcount: 1 }
1309}
1310
1311fn distribute_ptrcat_load(
1324 load: &Arc<UOp>,
1325 buffer: &Arc<UOp>,
1326 ptrcat: &Arc<UOp>,
1327 sources: &[Arc<UOp>],
1328) -> Option<Arc<UOp>> {
1329 let loads: Vec<Arc<UOp>> = sources
1330 .iter()
1331 .enumerate()
1332 .map(|(i, ptr)| {
1333 let load_dtype = match ptr.dtype() {
1334 DType::Ptr { base, .. } => base.as_ref().clone(),
1335 other => other.clone(),
1336 };
1337 let scalar_buf = match buffer.op() {
1343 Op::Vectorize { elements, .. } => elements.get(i).cloned().unwrap_or_else(|| buffer.clone()),
1344 _ => buffer.clone(),
1345 };
1346 let alt = match load.op() {
1347 Op::Load { alt, .. } => alt.clone(),
1348 _ => None,
1349 };
1350 UOp::load().buffer(scalar_buf).index(ptr.clone()).maybe_alt(alt).dtype(load_dtype).call()
1351 })
1352 .collect();
1353
1354 let cat_dtype = DType::Scalar(ptrcat.dtype().base()).vec(ptrcat.dtype().vcount());
1355 Some(UOp::cat().sources(loads).dtype(cat_dtype).call())
1356}
1357
1358fn distribute_ptrcat_store(
1360 sources: &[Arc<UOp>],
1361 value: &Arc<UOp>,
1362 ranges: &SmallVec<[Arc<UOp>; 4]>,
1363) -> Option<Arc<UOp>> {
1364 let value_vcount = value.dtype().vcount();
1365 let mut stores = Vec::new();
1366 let mut offset = 0usize;
1367
1368 for ptr in sources.iter() {
1369 let ptr_count = ptr_element_count(ptr);
1370 debug_assert!(offset + ptr_count <= value_vcount, "PTRCAT size mismatch");
1371 let gep_indices: Vec<usize> = (offset..offset + ptr_count).collect();
1372 let store_value = value.gep(gep_indices);
1373 stores.push(ptr.store_with_ranges(store_value, ranges.clone()));
1374 offset += ptr_count;
1375 }
1376
1377 Some(UOp::group(stores.into_iter().collect()))
1378}
1379
1380fn ptr_element_count(ptr: &Arc<UOp>) -> usize {
1388 match ptr.dtype() {
1389 DType::Ptr { base, .. } => base.vcount(),
1390 _ => 1,
1391 }
1392}
1393
1394fn split_load_store(ls: &Arc<UOp>, idx: &Arc<UOp>) -> Option<Arc<UOp>> {
1400 let Op::Index { buffer: buf, indices, .. } = idx.op() else { return None };
1401
1402 let sz = match ls.op() {
1406 Op::Load { index, .. } | Op::Store { index, .. } => ptr_element_count(index),
1407 _ => return None,
1408 };
1409 if sz == 1 {
1410 return None;
1411 }
1412
1413 let buf_dtype = buf.dtype();
1415 static IS_AMX: std::sync::LazyLock<bool> =
1416 std::sync::LazyLock::new(|| std::env::var("MOROK_AMX").is_ok_and(|v| v == "1"));
1417 let is_amx = *IS_AMX;
1418
1419 fn is_amx_tc_reg_ptr(dtype: &DType, sz: usize) -> bool {
1423 sz >= 16
1424 && dtype.base().is_float()
1425 && matches!(dtype, DType::Ptr { addrspace: AddrSpace::Reg, .. } | DType::Vector { .. })
1426 }
1427
1428 fn find_underlying_load(uop: &Arc<UOp>) -> Option<Arc<UOp>> {
1430 match uop.op() {
1431 Op::Gep { vector, .. } => find_underlying_load(vector),
1432 Op::Load { .. } => Some(uop.clone()),
1433 _ => None,
1434 }
1435 }
1436
1437 let is_amx_tc_acc = match ls.op() {
1438 Op::Store { value, .. } => {
1439 if let Some(load) = find_underlying_load(value) {
1442 if let Op::Load { index, .. } = load.op() {
1443 if let Op::Index { buffer, .. } = index.op() {
1444 let buf_dtype = buffer.dtype();
1445 is_amx && is_amx_tc_reg_ptr(&buf_dtype, sz)
1446 } else {
1447 false
1448 }
1449 } else {
1450 false
1451 }
1452 } else {
1453 false
1454 }
1455 }
1456 Op::Load { .. } => is_amx && is_amx_tc_reg_ptr(&buf_dtype, sz),
1457 _ => false,
1458 };
1459
1460 let no_fold = (!buf_dtype.base().is_float() && !matches!(buf_dtype, DType::Image { .. }))
1463 || (matches!(buf_dtype, DType::Ptr { addrspace: AddrSpace::Reg, .. }) && !is_amx_tc_acc);
1464
1465 let mut lengths = if no_fold {
1466 vec![1]
1467 } else if matches!(buf_dtype, DType::Image { .. }) {
1468 vec![4, 1]
1469 } else if is_amx {
1470 vec![16, 8, 4, 2, 1] } else {
1472 vec![4, 2, 1]
1476 };
1477
1478 if !is_amx_tc_acc && let Some(offset) = indices.first() {
1483 lengths.retain(|&len| offset_divides_evenly(offset, len));
1484 }
1485
1486 let scalar_dtype = buf_dtype.scalar_dtype();
1488 let mut ret = Vec::new();
1489 let mut pos = 0usize;
1490
1491 while pos < sz {
1492 for &fold_len in &lengths {
1493 if pos + fold_len > sz {
1494 continue;
1495 }
1496 let lidx = if pos == 0 { idx.clone() } else { offset_index(idx, pos as i64) };
1497 let lidx = if fold_len > 1 { lidx.cast(make_vec_ptr_dtype(buf, fold_len)) } else { lidx };
1498
1499 match ls.op() {
1500 Op::Store { value, ranges, .. } => {
1501 ret.push(lidx.store_with_ranges(value.gep((pos..pos + fold_len).collect()), ranges.clone()));
1502 }
1503 Op::Load { buffer, .. } => {
1504 ret.push(UOp::load().buffer(buffer.clone()).index(lidx).dtype(scalar_dtype.vec(fold_len)).call());
1505 }
1506 _ => return None,
1507 }
1508 pos += fold_len;
1509 break;
1510 }
1511 }
1512
1513 if ret.len() <= 1 {
1514 return None;
1515 }
1516
1517 match ls.op() {
1518 Op::Load { .. } => Some(UOp::cat().sources(ret).dtype(scalar_dtype.vec(sz)).call()),
1519 Op::Store { .. } => Some(UOp::group(ret.into_iter().collect())),
1520 _ => None,
1521 }
1522}
1523
1524fn offset_divides_evenly(offset: &Arc<UOp>, len: usize) -> bool {
1527 if len == 0 {
1529 return false;
1530 }
1531 if len == 1 {
1533 return true;
1534 }
1535 let v = len as i64;
1536
1537 match offset.op() {
1538 Op::Const(cv) => matches!(cv.0, ConstValue::Int(n) if n % v == 0),
1540
1541 Op::VConst { values } => values.iter().all(|val| matches!(val, ConstValue::Int(n) if n % v == 0)),
1543
1544 Op::Binary(BinaryOp::Add, left, right) => offset_divides_evenly(left, len) && offset_divides_evenly(right, len),
1546
1547 Op::Binary(BinaryOp::Mul, left, right) => {
1549 let check_const =
1550 |c: &Arc<UOp>| matches!(c.op(), Op::Const(cv) if matches!(cv.0, ConstValue::Int(n) if n % v == 0));
1551 check_const(left)
1552 || check_const(right)
1553 || offset_divides_evenly(left, len)
1554 || offset_divides_evenly(right, len)
1555 }
1556
1557 _ => false,
1558 }
1559}
1560
1561fn offset_index(idx: &Arc<UOp>, offset: i64) -> Arc<UOp> {
1562 let Op::Index { buffer, indices, gate } = idx.op() else {
1563 return idx.clone();
1564 };
1565 let new_indices: SmallVec<[Arc<UOp>; 4]> = indices
1566 .iter()
1567 .enumerate()
1568 .map(|(i, index_expr)| if i == 0 { index_expr.add(&index_expr.const_like(offset)) } else { index_expr.clone() })
1569 .collect();
1570
1571 UOp::index()
1572 .buffer(buffer.clone())
1573 .indices(new_indices)
1574 .maybe_gate(gate.clone())
1575 .ptr(true)
1576 .call()
1577 .expect("ICE: unable to create index")
1578}
1579
1580fn image_fixup(ls: &Arc<UOp>) -> Option<Arc<UOp>> {
1594 let (index, is_load) = match ls.op() {
1597 Op::Load { index, .. } => (index, true),
1598 Op::Store { index, .. } => (index, false),
1599 _ => return None,
1600 };
1601
1602 if let Op::Cast { src: inner_idx, dtype: cast_dtype } = index.op()
1604 && let Op::Index { buffer: img_buf, indices, gate } = inner_idx.op()
1605 {
1606 let DType::Image { shape, .. } = img_buf.dtype() else { return None };
1608
1609 if cast_dtype.vcount() != 4 {
1611 return None;
1612 }
1613
1614 let lin_idx = indices.first()?;
1616 let x = lin_idx.get_idx();
1617 let valid = lin_idx.get_valid();
1618
1619 let width = shape.get(1).copied().unwrap_or(1) as i64;
1621
1622 let four = UOp::index_const(4);
1624 let width_const = UOp::index_const(width);
1625 let stride = UOp::index_const(4 * width);
1626
1627 let x_coord = x.idiv(&four).mod_(&width_const);
1628 let y_coord = x.idiv(&stride);
1629
1630 let oidx = UOp::vectorize(smallvec::smallvec![x_coord, y_coord]);
1632
1633 let new_idx_expr = if matches!(valid.op(), Op::Const(cv) if cv.0 == ConstValue::Bool(true)) {
1635 oidx
1636 } else {
1637 oidx.valid(valid)
1638 };
1639
1640 let new_idx = if matches!(inner_idx.dtype(), DType::Ptr { .. }) {
1643 UOp::index()
1644 .buffer(img_buf.clone())
1645 .indices(vec![new_idx_expr])
1646 .maybe_gate(gate.clone())
1647 .ptr(true)
1648 .call()
1649 .ok()?
1650 } else {
1651 UOp::index()
1652 .buffer(img_buf.clone())
1653 .indices(vec![new_idx_expr])
1654 .maybe_gate(gate.clone())
1655 .dtype(inner_idx.dtype())
1656 .call()
1657 .ok()?
1658 };
1659
1660 return Some(ls.replace().src(vec![new_idx]).call());
1662 }
1663
1664 if let Op::Index { buffer: img_buf, indices, gate } = index.op() {
1666 let DType::Image { shape, .. } = img_buf.dtype() else { return None };
1667
1668 let lin_idx = indices.first()?;
1670 let x = lin_idx.get_idx();
1671
1672 if x.dtype().vcount() == 2 {
1674 return None; }
1676
1677 if !is_load {
1679 tracing::warn!("image_fixup: STORE with unfoldable image not supported");
1680 return None;
1681 }
1682
1683 let valid = lin_idx.get_valid();
1684
1685 let width = shape.get(1).copied().unwrap_or(1) as i64;
1687
1688 let four = UOp::index_const(4);
1690 let width_const = UOp::index_const(width);
1691 let stride = UOp::index_const(4 * width);
1692
1693 let x_coord = x.idiv(&four).mod_(&width_const);
1694 let y_coord = x.idiv(&stride);
1695
1696 let oidx = UOp::vectorize(smallvec::smallvec![x_coord, y_coord]);
1697
1698 let new_idx_expr = if matches!(valid.op(), Op::Const(cv) if cv.0 == ConstValue::Bool(true)) {
1699 oidx
1700 } else {
1701 oidx.valid(valid)
1702 };
1703
1704 let new_idx = if matches!(index.dtype(), DType::Ptr { .. }) {
1706 UOp::index()
1707 .buffer(img_buf.clone())
1708 .indices(vec![new_idx_expr])
1709 .maybe_gate(gate.clone())
1710 .ptr(true)
1711 .call()
1712 .ok()?
1713 } else {
1714 UOp::index()
1715 .buffer(img_buf.clone())
1716 .indices(vec![new_idx_expr])
1717 .maybe_gate(gate.clone())
1718 .dtype(index.dtype())
1719 .call()
1720 .ok()?
1721 };
1722
1723 let vec4_dtype = ls.dtype().vec(4);
1726 let vec_load = UOp::load().buffer(ls.load_buffer()?).index(new_idx).dtype(vec4_dtype).call();
1727
1728 let x_mod_4 = x.mod_(&four);
1730 let nan = ls.const_like(ConstValue::Float(f64::NAN));
1731
1732 let result = (0..4).rev().fold(nan, |ret, i| {
1733 let i_const = UOp::index_const(i);
1734 let not_eq = x_mod_4.ne(&i_const);
1735 let gep_i = vec_load.gep(vec![i as usize]);
1736 UOp::try_where(not_eq, ret, gep_i).expect("WHERE")
1737 });
1738
1739 return Some(result);
1740 }
1741
1742 None
1743}
1744
1745use crate::symbolic::dce::reduce_identity;
1750
/// Builds the pattern matcher that lowers `REDUCE` nodes, using a [`ReduceContext`] as shared state.
///
/// It contains two rules:
///
/// * every `REDUCE` is rewritten by `reduce_to_acc`, which also records the `END` node it creates
///   in the context;
/// * the `SINK` is rewritten by `ReduceContext::merge_reduce_ends`, which merges the recorded
///   `END` nodes that close the same set of ranges and clears the record.
pub fn pm_reduce() -> TypedPatternMatcher<ReduceContext> {
    crate::patterns! {
        @context ReduceContext;

        // REDUCE -> explicit accumulator; the resulting END is registered in `ctx`.
        red @ Reduce(_, ..) => {
            reduce_to_acc(red, ctx)
        },

        // SINK -> merge the ENDs registered so far (no rewrite when there is nothing to merge).
        Sink { sources: _sources } => {
            ctx.merge_reduce_ends(_sources)
        },
    }
}
1793
1794fn horizontal_reduce(inp: &Arc<UOp>, out_dtype: &DType) -> Vec<Arc<UOp>> {
1796 if inp.dtype() == *out_dtype {
1797 return vec![inp.clone()];
1798 }
1799 let inp_vcount = inp.dtype().vcount();
1800 let out_vcount = out_dtype.vcount();
1801 assert!(
1802 inp_vcount.is_multiple_of(out_vcount),
1803 "horizontal mismatch: inp.dtype={:?} (vcount={}), out_dtype={:?} (vcount={})",
1804 inp.dtype(),
1805 inp_vcount,
1806 out_dtype,
1807 out_vcount
1808 );
1809 let horizontal_amount = inp_vcount / out_vcount;
1810 (0..horizontal_amount).map(|i| inp.gep((i..inp_vcount).step_by(horizontal_amount).collect())).collect()
1811}
1812
1813fn reduce_to_acc(red: &Arc<UOp>, ctx: &mut ReduceContext) -> Option<Arc<UOp>> {
1815 let Op::Reduce { src: inp, ranges: reduce_range, reduce_op } = red.op() else { return None };
1816
1817 let lst = horizontal_reduce(inp, &red.dtype());
1818 debug_assert!(lst.iter().all(|x| x.dtype() == red.dtype()), "horizontal reduction mismatch");
1819
1820 if reduce_range.is_empty() {
1822 return lst.into_iter().reduce(|a, b| apply_reduce_binary(*reduce_op, a, b, &red.dtype()));
1823 }
1824
1825 let topo = inp.toposort();
1827 let ended: HashSet<u64> = topo
1828 .iter()
1829 .filter_map(|n| if let Op::End { ranges, .. } = n.op() { Some(ranges.iter().map(|r| r.id)) } else { None })
1830 .flatten()
1831 .collect();
1832 let reduce_ids: HashSet<u64> = reduce_range.iter().map(|r| r.id).collect();
1833 let input_ranges: SmallVec<[Arc<UOp>; 4]> = topo
1834 .iter()
1835 .filter(|n| matches!(n.op(), Op::Range { .. }) && !reduce_ids.contains(&n.id) && !ended.contains(&n.id))
1836 .cloned()
1837 .collect();
1838
1839 let identity = reduce_identity(*reduce_op, red.dtype());
1841 let acc = UOp::define_reg_typed(1, red.dtype());
1842 let zero = UOp::index_const(0);
1843 let make_idx = |buf: Arc<UOp>| UOp::index().buffer(buf).indices(vec![zero.clone()]).call().expect("index");
1844
1845 let acc_init = make_idx(acc.after(input_ranges)).store_value(identity);
1847
1848 let mut loop_deps: SmallVec<[Arc<UOp>; 4]> = smallvec::smallvec![acc_init];
1850 loop_deps.extend(reduce_range.iter().cloned());
1851 let acc_loop = make_idx(acc.after(loop_deps));
1852 let lst_with_acc = std::iter::once(acc_loop).chain(lst);
1853
1854 let ret = lst_with_acc.reduce(|a, b| apply_reduce_binary(*reduce_op, a, b, &red.dtype()))?;
1856
1857 let store_end = make_idx(acc.clone()).store_value(ret).end(reduce_range.clone());
1859 ctx.register_end(&store_end);
1860 Some(make_idx(acc.after(smallvec::smallvec![store_end])))
1861}
1862
1863fn apply_reduce_binary(reduce_op: ReduceOp, a: Arc<UOp>, b: Arc<UOp>, dtype: &DType) -> Arc<UOp> {
1865 debug_assert!(a.dtype() == b.dtype(), "reduce operand dtype mismatch");
1866 match reduce_op {
1867 ReduceOp::Add => UOp::new(Op::Binary(BinaryOp::Add, a, b), dtype.clone()),
1868 ReduceOp::Mul => UOp::new(Op::Binary(BinaryOp::Mul, a, b), dtype.clone()),
1869 ReduceOp::Max => UOp::new(Op::Binary(BinaryOp::Max, a, b), dtype.clone()),
1870 ReduceOp::Min => {
1871 let cond = UOp::new(Op::Binary(BinaryOp::Lt, a.clone(), b.clone()), DType::Bool.vec(dtype.vcount()));
1872 UOp::try_where(cond, a, b).expect("WHERE")
1873 }
1874 }
1875}