use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
use std::sync::LazyLock;
use itertools::Itertools;
use morok_dtype::{AddrSpace, DType, ScalarDType};
use morok_ir::{BinaryOp, ConstValue, Op, ReduceOp, TernaryOp, UOp, UOpKey, UnaryOp, WmmaMetadata};
use crate::TypedPatternMatcher;
use smallvec::SmallVec;
/// Accumulates `END` nodes, bucketed by the sorted ids of the ranges they close,
/// so that sibling `END`s over the same range set can later be merged by
/// [`ReduceContext::merge_reduce_ends`].
#[derive(Debug, Default)]
pub struct ReduceContext {
// Sorted range-id list -> every registered END that closes exactly that range set.
range_to_ends: HashMap<SmallVec<[u64; 4]>, Vec<Arc<UOp>>>,
}
impl ReduceContext {
    /// Records `end` under the sorted id list of the ranges it closes.
    ///
    /// Nodes that are not `END` are ignored.
    pub fn register_end(&mut self, end: &Arc<UOp>) {
        let Op::End { ranges, .. } = end.op() else { return };
        let mut range_ids: SmallVec<[u64; 4]> = ranges.iter().map(|r| r.id).collect();
        range_ids.sort_unstable();
        self.range_to_ends.entry(range_ids).or_default().push(Arc::clone(end));
    }

    /// Merges every group of registered `END`s that close the same range set.
    ///
    /// Always empties the registry. Returns `None` when no group has more than one
    /// `END`. Otherwise returns a new `SINK` over `sources` in which each grouped
    /// `END` is replaced by one `END` over a `GROUP` of their computations.
    pub fn merge_reduce_ends(&mut self, sources: &SmallVec<[Arc<UOp>; 4]>) -> Option<Arc<UOp>> {
        let registered = std::mem::take(&mut self.range_to_ends);
        #[allow(clippy::mutable_key_type)]
        let subs = build_end_merge_subs(&registered);
        if subs.is_empty() {
            None
        } else {
            Some(UOp::sink(sources.to_vec()).substitute(&subs))
        }
    }
}
/// Builds a substitution map that merges sibling `END` nodes.
///
/// For every bucket with two or more `END`s, all of them are mapped to a single
/// `END` that closes the bucket's ranges (taken from the first entry) over a
/// `GROUP` of every member's computation. Buckets of size 0 or 1 contribute nothing.
///
/// # Panics
/// If a bucket with more than one entry contains a node that is not `END`.
#[allow(clippy::mutable_key_type)]
fn build_end_merge_subs(range_to_ends: &HashMap<SmallVec<[u64; 4]>, Vec<Arc<UOp>>>) -> HashMap<UOpKey, Arc<UOp>> {
    let mut replacements = HashMap::new();
    for group in range_to_ends.values().filter(|g| g.len() > 1) {
        let bodies: Vec<Arc<UOp>> = group
            .iter()
            .map(|end| match end.op() {
                Op::End { computation, .. } => computation.clone(),
                _ => unreachable!(),
            })
            .collect();
        let Op::End { ranges: shared_ranges, .. } = group[0].op() else { unreachable!() };
        let merged = UOp::group(bodies).end(shared_ranges.clone());
        replacements.extend(group.iter().map(|end| (UOpKey(end.clone()), merged.clone())));
    }
    replacements
}
/// Merges `END` nodes anywhere under `sink` that close exactly the same set of ranges.
///
/// Returns `sink` unchanged if it is not a `SINK` or if nothing needs merging.
/// Otherwise rebuilds the sink with each group of sibling `END`s replaced by one
/// `END` over a `GROUP` of their computations.
pub fn merge_sibling_ends(sink: &Arc<UOp>) -> Arc<UOp> {
    let Op::Sink { sources } = sink.op() else { return sink.clone() };
    let mut grouped: HashMap<SmallVec<[u64; 4]>, Vec<Arc<UOp>>> = HashMap::new();
    for node in sink.toposort() {
        let Op::End { ranges, .. } = node.op() else { continue };
        let mut range_ids: SmallVec<[u64; 4]> = ranges.iter().map(|r| r.id).collect();
        range_ids.sort_unstable();
        grouped.entry(range_ids).or_default().push(node.clone());
    }
    #[allow(clippy::mutable_key_type)]
    let subs = build_end_merge_subs(&grouped);
    if subs.is_empty() {
        sink.clone()
    } else {
        UOp::sink(sources.to_vec()).substitute(&subs)
    }
}
use crate::rewrite::graph_rewrite;
use crate::symbolic::patterns::sym;
/// Runs the devectorization pipeline over `ast` as a single graph rewrite.
///
/// The combined matcher is the sum of the symbolic simplifier (`sym`),
/// `devectorize_patterns`, `load_store_folding_patterns`,
/// `correct_load_store_patterns` and `load_store_indexing_patterns`.
pub fn devectorize(ast: &Arc<UOp>) -> Arc<UOp> {
// Built once per process (LazyLock) and reused for every call.
static COMBINED: LazyLock<TypedPatternMatcher> = LazyLock::new(|| {
sym()
+ devectorize_patterns()
+ load_store_folding_patterns()
+ correct_load_store_patterns()
+ load_store_indexing_patterns()
});
graph_rewrite(&*COMBINED, ast.clone(), &mut ())
}
/// Patterns that keep boolean values out of memory by going through `u8`.
///
/// - `STORE` of a bool-based value: the value is cast to the same dtype with its
///   base replaced by `u8` before storing.
/// - `LOAD` with a bool-based result: reloads with a `u8` base and casts the result
///   back to the original bool dtype.
/// - `BITCAST` from or to a bool-based dtype: replaced with a value `CAST`.
pub fn bool_storage_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
Store { index, value, ranges } if value.dtype().base().is_bool() => {
let uint8_dtype = value.dtype().with_base(ScalarDType::UInt8);
Some(index.store_with_ranges(value.cast(uint8_dtype), ranges.clone()))
},
load @ Load { buffer, index } if load.dtype().base().is_bool() => {
let uint8_dtype = load.dtype().with_base(ScalarDType::UInt8);
let uint8_load = UOp::load().buffer(buffer.clone()).index(index.clone()).dtype(uint8_dtype).call();
Some(uint8_load.cast(load.dtype()))
},
// A bit-level reinterpretation is not meaningful for bool, so use a value cast.
BitCast { src, dtype } if src.dtype().base().is_bool() || dtype.base().is_bool() => {
Some(src.cast(dtype.clone()))
},
}
}
/// Context for the software float-emulation patterns ([`pm_float_decomp`] and
/// [`pm_float_decomp_store`]).
#[derive(Debug, Clone)]
pub struct Fp8DecompCtx {
/// Float type being emulated. Memory of this type is retyped to the same-width
/// unsigned integer, and values are converted with `f2f`. Presumably a narrow
/// format such as FP8 (per the type name); confirm at call sites.
pub from: ScalarDType,
/// Float type that arithmetic on `from` values is performed in.
pub to: ScalarDType,
}
/// Shifts the integer-valued `v` right by `s` bits, rounding to nearest, ties to even.
///
/// The result rounds up when the highest discarded bit is set and either some
/// lower discarded bit is set or the kept least-significant bit is odd.
/// Requires `s >= 1`, because `s - 1` is computed.
fn rne(v: &Arc<UOp>, s: u32) -> Arc<UOp> {
let one = v.const_like(1);
let shifted = v.shr(&v.const_like(s));
// Highest discarded bit (the "half" bit).
let half_bit = v.shr(&v.const_like(s - 1)).and_(&one);
// Any discarded bit below the half bit ("sticky").
let remainder_mask = v.const_like((1i64 << (s - 1)) - 1);
let has_remainder = v.and_(&remainder_mask).ne(&v.const_like(0)).cast(v.dtype());
let lsb = shifted.and_(&one);
// Round up if above half, or exactly half and the kept value is odd.
let round_up = half_bit.and_(&has_remainder.or_(&lsb));
shifted.try_add(&round_up).expect("rne: add failed")
}
/// Converts float bits from format `fr` to format `to` using integer operations only.
///
/// `v` is expected to hold the raw bits of a `fr` value as an unsigned integer of
/// `fr`'s width.
///
/// - Widening (`fe <= te && fm < tm`):
///   - The exponent is re-biased and the mantissa shifted left.
///   - Inputs with a zero exponent become zero, keeping the sign.
///   - The NaN/Inf pattern gets the target's all-ones exponent. For FP8E4M3 sources,
///     only the all-ones magnitude counts as NaN.
///   - The result is bitcast to `to`.
/// - Narrowing (`fe >= te && fm > tm`):
///   - The value is first clamped with `f2f_clamp`.
///   - Mantissa bits are dropped with round-to-nearest-even (`rne`).
///   - Values whose exponent would underflow are flushed to signed zero.
///   - NaN/Inf sources are mapped to the target's NaN/Inf encoding.
///   - The result stays in `to`'s unsigned-integer type and is NOT bitcast.
///
/// # Panics
/// For any other exponent/mantissa combination.
fn f2f(v: &Arc<UOp>, fr: ScalarDType, to: ScalarDType) -> Arc<UOp> {
let (fe, fm) = fr.finfo();
let (te, tm) = to.finfo();
let fs = fr.bitsize();
let ts = to.bitsize();
let fb = fr.exponent_bias() as i64;
let tb = to.exponent_bias() as i64;
let fr_uint = DType::Scalar(fr.float_to_uint());
let to_uint = DType::Scalar(to.float_to_uint());
if fe <= te && fm < tm {
// --- widening: move the sign bit to the top of the wider word ---
let sign_mask = v.const_like(1i64 << (fs - 1));
let sign = v.and_(&sign_mask).cast(to_uint.clone()).shl(&v.const_like(ts - fs).cast(to_uint.clone()));
let nosign_mask = v.const_like((1i64 << (fs - 1)) - 1);
let nosign = v.and_(&nosign_mask).cast(to_uint.clone());
let exp = nosign.shr(&nosign.const_like(fm));
// Shift the mantissa into place and add the bias difference to the exponent field.
let norm = nosign
.shl(&nosign.const_like(tm - fm))
.try_add(&nosign.const_like((tb - fb) << tm))
.expect("f2f: add failed");
// Keep the mantissa and force the target exponent to all ones.
let nan_val = nosign.shl(&nosign.const_like(tm - fm)).or_(&nosign.const_like(((1i64 << te) - 1) << tm));
let is_nan = if fr == ScalarDType::FP8E4M3 {
nosign.eq(&nosign.const_like((1i64 << (fm + fe)) - 1))
} else {
exp.eq(&exp.const_like((1i64 << fe) - 1))
};
let zero = nosign.const_like(0);
let exp_is_zero = exp.eq(&zero);
let inner = UOp::try_where(is_nan, nan_val, norm).expect("f2f: where failed");
let result = UOp::try_where(exp_is_zero, zero, inner).expect("f2f: where failed");
sign.or_(&result).bitcast(DType::Scalar(to))
} else if fe >= te && fm > tm {
// --- narrowing: clamp into the target range first ---
let clamped = f2f_clamp(&v.bitcast(DType::Scalar(fr)), to);
let v = clamped.bitcast(fr_uint);
let sign = v.shr(&v.const_like(fs - ts)).and_(&v.const_like(1i64 << (ts - 1)));
let nosign_mask = v.const_like((1i64 << (fs - 1)) - 1);
let nosign = v.and_(&nosign_mask);
// Round away the extra mantissa bits, then re-bias the exponent.
let norm = rne(&nosign, fm - tm)
.try_sub(&nosign.const_like((fb - tb) << tm))
.expect("f2f: sub failed")
.cast(to_uint.clone());
let exp_field = nosign.shr(&nosign.const_like(fm)).and_(&nosign.const_like((1i64 << fe) - 1));
// Source exponents below the smallest normal target exponent flush to zero.
let underflow = exp_field.lt(&exp_field.const_like(1 + fb - tb));
let nan_mantissa = if to == ScalarDType::FP8E4M3 {
sign.const_like((1i64 << tm) - 1).cast(to_uint.clone())
} else {
nosign.shr(&nosign.const_like(fm - tm)).and_(&nosign.const_like((1i64 << tm) - 1)).cast(to_uint.clone())
};
let nan_exp = sign.const_like(((1i64 << te) - 1) << tm).cast(to_uint.clone());
let nan = sign.cast(to_uint.clone()).or_(&nan_mantissa).or_(&nan_exp);
let is_nan = exp_field.eq(&exp_field.const_like((1i64 << fe) - 1));
let zero = sign.const_like(0).cast(to_uint.clone());
let normal = sign.cast(to_uint.clone()).or_(&UOp::try_where(underflow, zero, norm).expect("f2f: where failed"));
UOp::try_where(is_nan, nan, normal).expect("f2f: where failed")
} else {
panic!("f2f: unsupported conversion {fr:?} -> {to:?}")
}
}
/// Clamps the float value `val` to the finite range of `dt`, passing NaN through.
///
/// The largest finite value of `dt` is derived from its exponent and mantissa widths.
/// FP8E4M3 is special-cased: its all-ones exponent is still finite, and only the
/// all-ones mantissa is reserved. Values beyond ±max saturate to ±max for FP8
/// formats and to ±infinity otherwise.
fn f2f_clamp(val: &Arc<UOp>, dt: ScalarDType) -> Arc<UOp> {
let (e, m) = dt.finfo();
let (max_exp, max_man): (i64, i64) =
if dt == ScalarDType::FP8E4M3 { ((1 << e) - 1, (1 << m) - 2) } else { ((1 << e) - 2, (1 << m) - 1) };
// max = 2^(max_exp - bias) * (1 + max_man / 2^m)
let mx_f64 =
f64::powi(2.0, (max_exp - dt.exponent_bias() as i64) as i32) * (1.0 + max_man as f64 / (1i64 << m) as f64);
let mx = val.const_like(mx_f64);
let neg_mx = val.const_like(-mx_f64);
let sat = if dt.is_fp8() { mx.clone() } else { val.const_like(f64::INFINITY) };
let neg_sat = if dt.is_fp8() { neg_mx.clone() } else { val.const_like(f64::NEG_INFINITY) };
// x != x is true only for NaN.
let is_nan = val.ne(val);
let below = val.lt(&neg_mx);
let above = mx.lt(val);
let clamped_above = UOp::try_where(above, sat, val.clone()).expect("f2f_clamp: where failed");
let clamped = UOp::try_where(below, neg_sat, clamped_above).expect("f2f_clamp: where failed");
UOp::try_where(is_nan, val.clone(), clamped).expect("f2f_clamp: where failed")
}
/// Rewrites stores whose index is typed `ctx.from`.
///
/// The value is cast to `ctx.to`, converted to `ctx.from` bits with `f2f`, and
/// stored through the same index retyped to a pointer whose base is `ctx.from`'s
/// unsigned-integer equivalent.
pub fn pm_float_decomp_store() -> crate::TypedPatternMatcher<Fp8DecompCtx> {
crate::patterns! {
@context Fp8DecompCtx;
Store { index, value, ranges }
if index.dtype().base() == ctx.from
=> {
let target_float = DType::Scalar(ctx.to);
let target_uint = DType::Scalar(ctx.to.float_to_uint());
let float_val = value.cast(target_float);
// Narrow `to` bits -> `from` bits (returned as an unsigned integer).
let result = f2f(&float_val.bitcast(target_uint), ctx.to, ctx.from);
let uint8_ptr = index.dtype().with_ptr_base(DType::Scalar(ctx.from.float_to_uint()))?;
let new_index = index.with_dtype(uint8_ptr);
Some(new_index.store_with_ranges(result, ranges.clone()))
},
}
}
/// Emulates the `ctx.from` float type in terms of `ctx.to`.
///
/// - `Param` (with `device: None`), `DefineLocal` and `Index` nodes typed
///   `ctx.from`: the pointer base is retyped to the same-width unsigned integer.
/// - `LOAD` of `ctx.from`: reads the raw integer bits and widens them with `f2f`.
/// - `CAST` to `ctx.from`: emulated by converting to `ctx.from` bits and back to
///   `ctx.to`, which applies the narrow format's rounding and clamping.
/// - Any other non-`BITCAST` float node of type `ctx.from`: recomputed in `ctx.to`,
///   with its `ctx.from`-typed sources cast to `ctx.to`.
pub fn pm_float_decomp() -> crate::TypedPatternMatcher<Fp8DecompCtx> {
crate::patterns! {
@context Fp8DecompCtx;
x if matches!(x.op(), Op::Param { device: None, .. } | Op::DefineLocal(_) | Op::Index { .. })
&& x.dtype().base() == ctx.from
=> {
let uint8_ptr = x.dtype().with_ptr_base(DType::Scalar(ctx.from.float_to_uint()))?;
Some(x.with_dtype(uint8_ptr))
},
load @ Load { buffer, index } if load.dtype().base() == ctx.from => {
let uint_dtype = DType::Scalar(ctx.from.float_to_uint());
let uint_load = UOp::load().buffer(buffer.clone()).index(index.clone())
.dtype(uint_dtype).call();
Some(f2f(&uint_load, ctx.from, ctx.to))
},
// Round-trip through the narrow encoding so the value matches what `from` can represent.
x @ Cast { src: val, .. } if x.dtype().base() == ctx.from => {
let target = DType::Scalar(ctx.to);
let target_uint = DType::Scalar(ctx.to.float_to_uint());
let float_val = val.cast(target);
let fp8_bytes = f2f(&float_val.bitcast(target_uint.clone()), ctx.to, ctx.from);
Some(f2f(&fp8_bytes, ctx.from, ctx.to))
},
x if !matches!(x.op(), Op::BitCast { .. })
&& x.dtype().is_float()
&& x.dtype().base() == ctx.from
=> {
let target_dtype = DType::Scalar(ctx.to);
// Preserve the vector width of the original node.
let new_dtype = if x.dtype().vcount() > 1 {
target_dtype.vec(x.dtype().vcount())
} else {
target_dtype.clone()
};
let new_sources: Vec<Arc<UOp>> = x.op().sources().iter().map(|s| {
if s.dtype().base() == ctx.from {
s.cast(target_dtype.clone())
} else {
s.clone()
}
}).collect();
Some(x.with_sources(new_sources).with_dtype(new_dtype))
},
}
}
/// Lowering patterns (named for the render stage) that leave only simple vector forms.
///
/// - Vector `CONST` / `VCONST` become a `VECTORIZE` of scalar constants.
/// - `CAT` becomes a `VECTORIZE` of per-lane GEPs; a single-source `CAT` is unwrapped.
/// - GEPs:
///   - trivial and identity GEPs are removed;
///   - a GEP of a `VECTORIZE` picks the elements directly;
///   - a multi-index GEP becomes a `VECTORIZE` of single-index GEPs.
/// - Single-element `VECTORIZE` / `PTRCAT` are unwrapped.
/// - A gated `LOAD` without an `alt` gets a zero `alt`.
/// - `WHERE(gate, load, alt)` and `WHERE(!gate, alt, load)` fold `alt` into the
///   `LOAD` when the condition matches the load's INDEX gate.
pub fn pm_render() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
c @ Const(_) if c.dtype().vcount() > 1 => |c| {
let vcount = c.dtype().vcount();
let Op::Const(cv) = c.op() else { return None };
let scalar_const = UOp::const_(c.dtype().scalar_dtype(), cv.0);
let elements: SmallVec<[Arc<UOp>; 4]> = (0..vcount).map(|_| scalar_const.clone()).collect();
Some(UOp::vectorize(elements))
},
vc @ VConst { values } => |vc, values| {
let scalar_dtype = vc.dtype().scalar_dtype();
let elements: SmallVec<[Arc<UOp>; 4]> = values.iter()
.map(|v| UOp::const_(scalar_dtype.clone(), *v))
.collect();
Some(UOp::vectorize(elements))
},
Cat { sources } if sources.len() == 1 => Some(sources[0].clone()),
Cat { sources } => {
let elements: SmallVec<[Arc<UOp>; 4]> = sources.iter()
.flat_map(|src| {
let n = src.dtype().vcount();
(0..n).map(move |i| if n == 1 { src.clone() } else { src.gep(vec![i]) })
})
.collect();
Some(UOp::vectorize(elements))
},
Gep { vector, indices } if vector.dtype().vcount() == 1 && indices.len() == 1 && indices[0] == 0
~> |vector| Arc::clone(vector),
Gep { vector, indices } if is_identity_gep(vector, indices) => Some(vector.clone()),
Gep { vector: Vectorize { elements }, indices } => {
let extracted: SmallVec<[Arc<UOp>; 4]> = indices.iter()
.filter_map(|&i| elements.get(i).cloned())
.collect();
// Bail out if any index is out of range.
if extracted.len() != indices.len() { return None; }
Some(if extracted.len() == 1 { extracted[0].clone() } else { UOp::vectorize(extracted) })
},
Gep { vector, indices } if indices.len() > 1 => |vector, indices| {
let geps: SmallVec<[Arc<UOp>; 4]> = indices.iter()
.map(|&i| vector.gep(vec![i]))
.collect();
Some(UOp::vectorize(geps))
},
Vectorize { elements } if elements.len() == 1 => Some(elements[0].clone()),
PtrCat { sources } if sources.len() == 1 => Some(sources[0].clone()),
load @ Load { index, alt: None, .. } if has_gate(index) => |load, index| {
let alt_value = load.const_like(ConstValue::Int(0));
Some(UOp::load().buffer(load.load_buffer()?).index(index.clone()).alt(alt_value).dtype(load.dtype()).call())
},
Where(cond, load @ Load { index, .. }, alt)
if index_has_gate_matching(index, cond)
=> |cond, load, index, alt| {
let casted_alt = cast_alt_avoiding_roundtrip(alt, &load.dtype());
let new_load = UOp::load()
.buffer(load.load_buffer()?)
.index(index.clone())
.alt(casted_alt)
.dtype(load.dtype())
.call();
Some(new_load.cast(alt.dtype()))
},
Where(cond, alt, load @ Load { index, .. })
if index_has_inverted_gate_matching(index, cond)
=> |cond, alt, load, index| {
let casted_alt = cast_alt_avoiding_roundtrip(alt, &load.dtype());
let new_load = UOp::load()
.buffer(load.load_buffer()?)
.index(index.clone())
.alt(casted_alt)
.dtype(load.dtype())
.call();
Some(new_load.cast(alt.dtype()))
},
}
}
/// Returns `alt` converted to `load_dtype`.
///
/// If `alt` is itself a `CAST` of a value already typed `load_dtype`, that inner
/// value is returned, which avoids a cast-then-uncast round trip.
fn cast_alt_avoiding_roundtrip(alt: &Arc<UOp>, load_dtype: &DType) -> Arc<UOp> {
    match alt.op() {
        Op::Cast { src: original, .. } if original.dtype() == *load_dtype => original.clone(),
        _ => alt.cast(load_dtype.clone()),
    }
}
/// True when `indices` selects every lane of `vector` in order (`[0, 1, ..., n-1]`).
fn is_identity_gep(vector: &Arc<UOp>, indices: &[usize]) -> bool {
    let lanes = vector.dtype().vcount();
    indices.len() == lanes && indices.iter().copied().eq(0..indices.len())
}
/// True when `index`, after peeling any `CAST` wrappers, is an `INDEX` with a gate.
fn has_gate(index: &Arc<UOp>) -> bool {
    let mut node = index;
    loop {
        match node.op() {
            Op::Cast { src, .. } => node = src,
            Op::Index { gate, .. } => return gate.is_some(),
            _ => return false,
        }
    }
}
/// True when `index`, after peeling `CAST`s, is an `INDEX` gated by exactly
/// `cond` (pointer identity).
fn index_has_gate_matching(index: &Arc<UOp>, cond: &Arc<UOp>) -> bool {
    let mut node = index;
    loop {
        match node.op() {
            Op::Cast { src, .. } => node = src,
            Op::Index { gate: Some(g), .. } => return Arc::ptr_eq(g, cond),
            _ => return false,
        }
    }
}
/// True when `index`, after peeling `CAST`s, is an `INDEX` whose gate is the
/// negation of `cond` (see `is_negation_of`).
fn index_has_inverted_gate_matching(index: &Arc<UOp>, cond: &Arc<UOp>) -> bool {
    let mut node = index;
    loop {
        match node.op() {
            Op::Cast { src, .. } => node = src,
            Op::Index { gate: Some(g), .. } => return is_negation_of(g, cond),
            _ => return false,
        }
    }
}
/// Returns `true` when `gate` is recognized as the logical negation of `cond`.
///
/// Recognized forms (operands compared by pointer identity; integer constants only):
/// - `gate = NOT(cond)`
/// - `cond = x < c` and `gate = (c - 1) < x`, i.e. `gate` is `x >= c`
/// - `cond = c < x` and `gate = x < (c + 1)`, i.e. `gate` is `x <= c`
///
/// Any other shape returns `false`. That is conservative: the caller only folds a
/// `WHERE` into a gated load when this returns `true`.
fn is_negation_of(gate: &Arc<UOp>, cond: &Arc<UOp>) -> bool {
    if let Op::Unary(UnaryOp::Not, inner) = gate.op()
        && Arc::ptr_eq(inner, cond)
    {
        return true;
    }
    if let Op::Binary(BinaryOp::Lt, gate_lhs, gate_rhs) = gate.op()
        && let Op::Binary(BinaryOp::Lt, cond_lhs, cond_rhs) = cond.op()
    {
        // cond: x < c  =>  !cond: x >= c  <=>  (c - 1) < x.
        // So gate's constant must be cond's constant minus one.
        if Arc::ptr_eq(gate_rhs, cond_lhs)
            && let (Op::Const(gate_cv), Op::Const(cond_cv)) = (gate_lhs.op(), cond_rhs.op())
            && is_const_minus_one(&cond_cv.0, &gate_cv.0)
        {
            return true;
        }
        // cond: c < x  =>  !cond: x <= c  <=>  x < (c + 1).
        // So gate's constant minus one must equal cond's constant. The previous
        // argument order accepted `x < c - 1`, which is not a negation.
        if Arc::ptr_eq(gate_lhs, cond_rhs)
            && let (Op::Const(cond_cv), Op::Const(gate_cv)) = (cond_lhs.op(), gate_rhs.op())
            && is_const_minus_one(&gate_cv.0, &cond_cv.0)
        {
            return true;
        }
    }
    false
}
/// True when `b == a - 1`, with `a` and `b` both signed or both unsigned integers.
///
/// Mixed or non-integer constants, and `a` at its type's minimum, yield `false`.
fn is_const_minus_one(a: &ConstValue, b: &ConstValue) -> bool {
    match (a, b) {
        (ConstValue::Int(hi), ConstValue::Int(lo)) => hi.checked_sub(1).is_some_and(|pred| pred == *lo),
        (ConstValue::UInt(hi), ConstValue::UInt(lo)) => hi.checked_sub(1).is_some_and(|pred| pred == *lo),
        _ => false,
    }
}
/// Splits a vector-typed ALU, `CAST` or `BITCAST` node into one scalar op per lane,
/// joined by a `VECTORIZE`.
///
/// Vector sources contribute the matching lane (via GEP); scalar sources are reused
/// for every lane. Returns `None` for non-vector nodes and for `WHERE` nodes whose
/// false branch is the invalid marker.
fn devectorize_alu(alu: &Arc<UOp>) -> Option<Arc<UOp>> {
    let lanes = alu.dtype().vcount();
    if lanes < 2 {
        return None;
    }
    let is_invalid_where =
        matches!(alu.op(), Op::Ternary(TernaryOp::Where, _, _, f) if UOp::is_invalid_marker(f));
    if is_invalid_where {
        return None;
    }
    let lane_dtype = alu.dtype().scalar_dtype();
    let srcs = alu.op().sources();
    let per_lane: SmallVec<[Arc<UOp>; 4]> = (0..lanes)
        .map(|lane| {
            let lane_srcs: Vec<Arc<UOp>> = srcs
                .iter()
                .map(|s| if s.dtype().vcount() > 1 { s.gep(vec![lane]) } else { s.clone() })
                .collect();
            match alu.op() {
                Op::Cast { .. } => lane_srcs[0].cast(lane_dtype.clone()),
                Op::BitCast { .. } => lane_srcs[0].bitcast(lane_dtype.clone()),
                _ => alu.replace().dtype(lane_dtype.clone()).src(lane_srcs).call(),
            }
        })
        .collect();
    Some(UOp::vectorize(per_lane))
}
/// Scalarizes every vector-typed binary, unary, ternary, `CAST` and `BITCAST` node
/// via `devectorize_alu`.
// `unused_variables` is allowed because of the macro-generated `op` loop bindings.
#[allow(unused_variables)]
pub fn no_vectorized_alu() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
for op in binary [*] {
alu @ op(_, _) if alu.dtype().vcount() > 1 => devectorize_alu(alu),
},
for op in unary [*] {
alu @ op(_) if alu.dtype().vcount() > 1 => devectorize_alu(alu),
},
for op in ternary [*] {
alu @ op(_, _, _) if alu.dtype().vcount() > 1 => devectorize_alu(alu),
},
alu @ Cast { src: _, .. } if alu.dtype().vcount() > 1 => devectorize_alu(alu),
alu @ BitCast { src: _, .. } if alu.dtype().vcount() > 1 => devectorize_alu(alu),
}
}
pub fn devectorize_patterns() -> &'static TypedPatternMatcher {
use std::sync::LazyLock;
static CACHED: LazyLock<TypedPatternMatcher> = LazyLock::new(|| {
cast_after_pattern() + no_vectorized_alu() + no_vectorized_wmma() + devectorize_buf_and_index_patterns()
});
&CACHED
}
/// Splits WMMA nodes wider than their native output size (see `devectorize_wmma`).
fn no_vectorized_wmma() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
wmma @ Wmma { a, b, c, metadata } if wmma.dtype().vcount() > wmma_expected_size(metadata)
=> devectorize_wmma(wmma, a, b, c, metadata),
}
}
/// Native WMMA output width: the product of the C-operand upcast axis sizes, at least 1.
fn wmma_expected_size(metadata: &WmmaMetadata) -> usize {
    let lanes: usize = metadata.upcast_axes.c.iter().map(|(_, extent)| *extent).product();
    lanes.max(1)
}
/// Splits a WMMA whose vector width exceeds its native output size into several
/// native-size WMMAs.
///
/// Each source `i` (A, B, C) is cut into consecutive GEP groups of
/// `upcast_axes.source_size(i)` lanes. The last group may be shorter. Group `g` of
/// every source feeds the `g`-th WMMA, and all outputs are re-assembled lane by
/// lane into one `VECTORIZE`.
///
/// Returns `None` when the width already matches, or (with a warning) when the
/// sources do not split into the same number of groups.
fn devectorize_wmma(
wmma: &Arc<UOp>,
a: &Arc<UOp>,
b: &Arc<UOp>,
c: &Arc<UOp>,
metadata: &WmmaMetadata,
) -> Option<Arc<UOp>> {
let out_sz = wmma_expected_size(metadata);
if wmma.dtype().vcount() == out_sz {
return None;
}
let sources: [&Arc<UOp>; 3] = [a, b, c];
let tsrcs: Vec<Vec<Arc<UOp>>> = sources
.iter()
.enumerate()
.map(|(i, src)| {
let ssz = metadata.upcast_axes.source_size(i);
let n = src.dtype().vcount();
// NOTE(review): `step_by` panics if `ssz` is 0 — confirm source_size() is always >= 1.
(0..n).step_by(ssz).map(|g| src.gep((g..g + ssz.min(n - g)).collect())).collect()
})
.collect();
let num_groups = tsrcs[0].len();
if tsrcs.iter().any(|t| t.len() != num_groups) {
tracing::warn!("WMMA devectorization: mismatched source group counts");
return None;
}
let wmma_ex: SmallVec<[Arc<UOp>; 4]> = (0..num_groups)
.flat_map(|g| {
let w = UOp::wmma(tsrcs[0][g].clone(), tsrcs[1][g].clone(), tsrcs[2][g].clone(), metadata.clone());
(0..out_sz).map(move |i| w.gep(vec![i]))
})
.collect();
Some(UOp::vectorize(wmma_ex))
}
/// Hoists a `CAST` out of an `AFTER`: `AFTER(CAST(x), deps)` becomes `CAST(AFTER(x, deps))`.
fn cast_after_pattern() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
After { passthrough: Cast { src, dtype }, deps }
=> |src, dtype, deps| {
let new_after = src.after(deps.clone());
Some(new_after.cast(dtype.clone()))
},
}
}
/// Scalarizes vector-element local/register buffers and the INDEX nodes over them.
///
/// - A `DefineLocal` / `DefineReg` with a vector element type is retyped to a
///   scalar-element buffer (`no_vectorized_buf`).
/// - An INDEX through a vector-pointer `CAST` of such a buffer is handled in one of
///   three forms:
///   - directly (`no_vectorized_index`);
///   - broadcast through a `VECTORIZE` of identical casts;
///   - via a `GEP` of the cast.
///   The last two go through `no_vectorized_index_precnt`.
fn devectorize_buf_and_index_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
def if matches!(def.op(), Op::DefineLocal(_) | Op::DefineReg { .. })
&& def.ptrdtype().is_some_and(|(base, _, _)| base.vcount() > 1)
=> no_vectorized_buf(def),
Index { buffer: Cast { src: buf, dtype: cast_dtype }, indices, gate }
if is_vectorized_local_reg_cast(buf, cast_dtype)
=> no_vectorized_index(buf, indices, gate, cast_dtype),
Index { buffer: Vectorize { elements }, indices, gate }
if is_vectorized_broadcast_cast(elements)
=> {
let first = elements.first()?;
let Op::Cast { src: buf, dtype: DType::Ptr { base, .. } } = first.op() else { return None };
let idx = indices.first()?;
// Broadcast form: every element reads the base lane 0.
no_vectorized_index_precnt(buf, idx, gate, base.vcount(), &vec![0; elements.len()])
},
Index { buffer: Gep { vector: Cast { src: buf, dtype: cast_dtype }, indices: gep_indices }, indices, gate }
if is_vectorized_local_reg_cast(buf, cast_dtype)
=> {
let DType::Ptr { base, .. } = cast_dtype else { return None };
let idx = indices.first()?;
no_vectorized_index_precnt(buf, idx, gate, base.vcount(), gep_indices)
},
}
}
/// True when `cast_dtype` is a pointer to a vector type and `buf` is (possibly
/// behind `AFTER`) a local or register definition.
fn is_vectorized_local_reg_cast(buf: &Arc<UOp>, cast_dtype: &DType) -> bool {
    let DType::Ptr { base, .. } = cast_dtype else { return false };
    base.vcount() > 1 && is_define_local_or_reg_or_after(buf)
}
/// True when the first element is a vector-pointer `CAST` of a local/register
/// definition (possibly behind `AFTER`). Returns `false` for an empty slice.
fn is_vectorized_broadcast_cast(elements: &SmallVec<[Arc<UOp>; 4]>) -> bool {
    let Some(head) = elements.first() else { return false };
    let Op::Cast { src, dtype: DType::Ptr { base, .. } } = head.op() else { return false };
    base.vcount() > 1 && is_define_local_or_reg_or_after(src)
}
/// True when `uop`, after stripping `AFTER` wrappers, is a `DefineLocal` or `DefineReg`.
fn is_define_local_or_reg_or_after(uop: &Arc<UOp>) -> bool {
    let unwrapped = uop.unwrap_after();
    matches!(unwrapped.op(), Op::DefineReg { .. } | Op::DefineLocal(_))
}
/// Retypes a buffer definition with a vector element type into one with scalar
/// elements.
///
/// The size, if known, is multiplied by the vector width. The definition is then
/// cast back to its original pointer type so existing users still type-check.
/// Returns `None` for non-pointer or scalar-element buffers.
fn no_vectorized_buf(buf: &Arc<UOp>) -> Option<Arc<UOp>> {
    let (element, addrspace, size) = buf.ptrdtype()?;
    let lanes = element.vcount();
    if lanes < 2 {
        return None;
    }
    let flat_ptr = DType::Ptr {
        base: Box::new(DType::Scalar(element.base())),
        addrspace,
        size: size.map(|n| n * lanes),
        vcount: 1,
    };
    Some(buf.with_dtype(flat_ptr).cast(buf.dtype()))
}
/// Rewrites `INDEX(CAST(buf -> ptr<vecN>), idx)` into a scalar-pointer INDEX over a
/// broadcast `buf`, with `N` lanes per index lane: element `idx * N + j` for `j in 0..N`.
///
/// A vector `idx` of width `W` yields `W * N` lanes, lane-major (all `N` elements of
/// idx lane 0 first). A vector gate is repeated `N` times per lane. A scalar gate is
/// broadcast when `idx` is a vector and passed through unchanged otherwise.
/// Returns `None` when `idx` is missing or the cast's pointee is not a vector.
fn no_vectorized_index(
buf: &Arc<UOp>,
indices: &SmallVec<[Arc<UOp>; 4]>,
gate: &Option<Arc<UOp>>,
cast_dtype: &DType,
) -> Option<Arc<UOp>> {
let idx = indices.first()?;
let DType::Ptr { base, .. } = cast_dtype else { return None };
let cnt = base.vcount();
if cnt <= 1 {
return None;
}
let idx_vcount = idx.dtype().vcount();
let total = cnt * idx_vcount;
let buf_broadcast = buf.broadcast(total);
let final_idx = if idx_vcount == 1 {
// idx*N + [0, 1, ..., N-1]
let idx_broadcast = idx.broadcast(cnt);
let cnt_broadcast = idx.const_like(cnt as i64).broadcast(cnt);
idx_broadcast.mul(&cnt_broadcast).add(&create_index_vector(0..cnt as i64))
} else {
// For each idx lane i: idx[i]*N + j for j in 0..N.
let elements: SmallVec<[Arc<UOp>; 4]> = (0..idx_vcount)
.flat_map(|i| {
let lane = idx.gep(vec![i]);
let cnt_const = UOp::const_(lane.dtype(), ConstValue::Int(cnt as i64));
let scaled = lane.mul(&cnt_const);
(0..cnt).map(move |j| scaled.add(&UOp::const_(scaled.dtype(), ConstValue::Int(j as i64))))
})
.collect();
UOp::vectorize(elements)
};
let expanded_gate = if idx_vcount > 1 {
gate.as_ref().map(|g| {
if g.dtype().vcount() > 1 {
let elements: SmallVec<[Arc<UOp>; 4]> = (0..idx_vcount)
.flat_map(|i| {
let lane = g.gep(vec![i]);
std::iter::repeat_n(lane, cnt)
})
.collect();
UOp::vectorize(elements)
} else {
g.broadcast(total)
}
})
} else {
gate.clone()
};
Some(
UOp::index()
.buffer(buf_broadcast)
.indices(vec![final_idx])
.maybe_gate(expanded_gate)
.ptr(true)
.call()
.expect("ICE unable to create index"),
)
}
/// Builds a `VECTORIZE` of index constants, one per value in `values`.
fn create_index_vector(values: impl IntoIterator<Item = i64>) -> Arc<UOp> {
    let lanes = values.into_iter().map(UOp::index_const).collect::<SmallVec<[Arc<UOp>; 4]>>();
    UOp::vectorize(lanes)
}
/// Scalar-pointer INDEX for a vector-pointer cast that was broadcast or GEP'd
/// (`precnt = input_gep.len()` lanes) before indexing.
///
/// For each of the `cnt` elements of the vector pointee and each entry `y` of
/// `input_gep`, emits index `idx * cnt + (c + y)`.
/// - Scalar `idx`: lanes are ordered element-major over `c`, and the gate is passed
///   through unchanged.
/// - Vector `idx` of width `W`: lanes are ordered idx-lane-major (`W * cnt * precnt`
///   lanes), and the gate is expanded to match.
fn no_vectorized_index_precnt(
buf: &Arc<UOp>,
idx: &Arc<UOp>,
gate: &Option<Arc<UOp>>,
cnt: usize,
input_gep: &[usize],
) -> Option<Arc<UOp>> {
let precnt = input_gep.len();
let idx_vcount = idx.dtype().vcount();
if idx_vcount == 1 {
let total = cnt * precnt;
// NOTE(review): when precnt > 1 this GEPs lane indices up to precnt-1 on a scalar
// `idx`; confirm `gep` on a scalar tolerates that (broadcast semantics).
let gep_arg: Vec<usize> = (0..cnt).flat_map(|_| 0..precnt).collect();
let sum_arg = (0..cnt).flat_map(|i| input_gep.iter().map(move |&y| (i + y) as i64));
let buf_broadcast = buf.broadcast(total);
let final_idx =
idx.gep(gep_arg).mul(&idx.const_like(cnt as i64).broadcast(total)).add(&create_index_vector(sum_arg));
Some(
UOp::index()
.buffer(buf_broadcast)
.indices(vec![final_idx])
.maybe_gate(gate.clone())
.ptr(true)
.call()
.expect("ICE: unable to create index"),
)
} else {
let per_lane = cnt * precnt;
let total = per_lane * idx_vcount;
let buf_broadcast = buf.broadcast(total);
let elements: SmallVec<[Arc<UOp>; 4]> = (0..idx_vcount)
.flat_map(|i| {
let lane = idx.gep(vec![i]);
let cnt_const = UOp::const_(lane.dtype(), ConstValue::Int(cnt as i64));
let scaled = lane.mul(&cnt_const);
(0..cnt).flat_map(move |c| {
let s = scaled.clone();
input_gep.iter().map(move |&y| s.add(&UOp::const_(s.dtype(), ConstValue::Int((c + y) as i64))))
})
})
.collect();
let final_idx = UOp::vectorize(elements);
// Each idx lane expands to `per_lane` index lanes, so the gate is repeated to match.
let expanded_gate = gate.as_ref().map(|g| {
if g.dtype().vcount() > 1 {
let elements: SmallVec<[Arc<UOp>; 4]> = (0..idx_vcount)
.flat_map(|i| {
let lane = g.gep(vec![i]);
std::iter::repeat_n(lane, per_lane)
})
.collect();
UOp::vectorize(elements)
} else {
g.broadcast(total)
}
});
Some(
UOp::index()
.buffer(buf_broadcast)
.indices(vec![final_idx])
.maybe_gate(expanded_gate)
.ptr(true)
.call()
.expect("ICE: unable to create index"),
)
}
}
/// Drops an INDEX gate that is the constant `true`.
pub fn load_store_indexing_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
index @ Index { buffer, indices, gate: Some(g) }
if matches!(g.op(), Op::Const(cv) if matches!(cv.0, ConstValue::Bool(true)))
~> UOp::new(Op::Index { buffer: buffer.clone(), indices: indices.clone(), gate: None }, index.dtype())
}
}
/// Makes memory reads explicit.
///
/// - An INDEX whose dtype is neither a pointer nor an image (a "value" INDEX) is
///   retyped to its buffer's dtype and wrapped in a LOAD of the INDEX's scalar dtype.
/// - `STORE(LOAD(idx), v)`, i.e. a store target that was wrapped by the previous
///   rule, becomes `STORE(idx, v)`.
pub fn pm_add_loads() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
// NOTE(review): `scalar_dtype()` drops any vector width of the INDEX dtype — confirm intended.
idx @ Index { buffer, .. } if !is_ptr_or_image_dtype(&idx.dtype()) => {
let new_idx = idx.with_dtype(buffer.dtype());
Some(UOp::load().buffer(buffer.clone()).index(new_idx).dtype(idx.dtype().scalar_dtype()).call())
},
Store { index: Load { index: inner_idx, .. }, value, ranges }
=> Some(inner_idx.store_with_ranges(value.clone(), ranges.clone())),
}
}
/// True for pointer and image dtypes, the dtypes that denote an address rather than a value.
fn is_ptr_or_image_dtype(dtype: &DType) -> bool {
    matches!(dtype, DType::Image { .. } | DType::Ptr { .. })
}
/// Folds an addition into a WMMA's accumulator.
///
/// `WMMA(a, b, c) + x` (either operand order) becomes `WMMA(a, b, c + x)`. This
/// applies only when the WMMA and `x` have the same dtype.
pub fn pm_wmma_accumulate() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
Add(wmma @ Wmma { a, b, c, metadata }, add) => |wmma, a, b, c, metadata, add| {
if wmma.dtype() != add.dtype() {
return None;
}
let new_c = c.add(add);
Some(UOp::wmma(a.clone(), b.clone(), new_c, metadata.clone()))
},
Add(add, wmma @ Wmma { a, b, c, metadata }) => |wmma, add, a, b, c, metadata| {
if wmma.dtype() != add.dtype() {
return None;
}
let new_c = c.add(add);
Some(UOp::wmma(a.clone(), b.clone(), new_c, metadata.clone()))
},
}
}
/// Turns vector INDEX nodes into PTRCATs of (possibly vector) pointers and
/// distributes loads/stores over them.
///
/// - A vector INDEX over a VECTORIZE of buffer definitions is expanded per lane
///   (`expand_index_to_vectorize`).
/// - A VECTORIZE of INDEXes is folded into consecutive groups (`fold_expanded_index`).
/// - A GEP under a LOAD or STORE is moved to the value side
///   (`move_gep_after_load` / `move_gep_on_store`).
/// - A LOAD or STORE through a PTRCAT is split per pointer
///   (`distribute_ptrcat_load` / `distribute_ptrcat_store`).
pub fn load_store_folding_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
index if is_vector_index(index) => expand_index_to_vectorize(index),
midx @ Vectorize { elements } if elements.iter().all(|e| matches!(e.op(), Op::Index { .. }))
=> fold_expanded_index(midx),
load @ Load { buffer, index: Gep { vector, indices } }
=> move_gep_after_load(load, buffer, vector, indices),
Store { index: Gep { vector, indices }, value, ranges }
=> move_gep_on_store(vector, indices, value, ranges),
load @ Load { buffer, index: ptrcat @ PtrCat { sources } }
=> distribute_ptrcat_load(load, buffer, ptrcat, sources),
Store { index: PtrCat { sources }, value, ranges }
=> distribute_ptrcat_store(sources, value, ranges),
}
}
/// Post-folding load/store fixups.
///
/// - LOAD/STORE through a `CAST(INDEX)` goes to `split_load_store`.
/// - Every LOAD/STORE goes to `image_fixup`.
pub fn correct_load_store_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
ls @ Load { index: Cast { src: idx @ Index { buffer: _, .. }, .. }, .. }
=> split_load_store(ls, idx),
ls @ Store { index: Cast { src: idx @ Index { buffer: _, .. }, .. }, .. }
=> split_load_store(ls, idx),
ls @ Load { buffer: _, index: _, alt: _ } => image_fixup(ls),
ls @ Store { index: _, value: _, ranges: _ } => image_fixup(ls),
}
}
/// True when `uop`, after stripping `AFTER`, is a `DefineLocal`, `DefineReg`, or a
/// `Param` with no device.
fn is_define_or_after(uop: &Arc<UOp>) -> bool {
    let unwrapped = uop.unwrap_after();
    matches!(unwrapped.op(), Op::Param { device: None, .. } | Op::DefineReg { .. } | Op::DefineLocal(_))
}
/// True for an INDEX that meets both conditions:
/// - its first index operand is a vector;
/// - its buffer is a non-empty `VECTORIZE` of buffer definitions (see `is_define_or_after`).
fn is_vector_index(uop: &Arc<UOp>) -> bool {
    let Op::Index { buffer, indices, .. } = uop.op() else { return false };
    let has_vector_idx = indices.first().is_some_and(|idx| idx.dtype().vcount() > 1);
    if !has_vector_idx {
        return false;
    }
    match buffer.op() {
        Op::Vectorize { elements } => !elements.is_empty() && elements.iter().all(is_define_or_after),
        _ => false,
    }
}
/// Rewrites `LOAD(buf, GEP(ptr, idxs))` to `GEP(LOAD(buf, ptr), idxs)`.
///
/// The inner load gets the original scalar type widened to `idxs.len()` lanes.
fn move_gep_after_load(
    load: &Arc<UOp>,
    buffer: &Arc<UOp>,
    gep_inner: &Arc<UOp>,
    gep_indices: &[usize],
) -> Option<Arc<UOp>> {
    let widened = load.dtype().scalar_dtype().vec(gep_indices.len());
    let direct_load = load.replace().dtype(widened).src(vec![buffer.clone(), gep_inner.clone()]).call();
    Some(direct_load.gep(gep_indices.to_vec()))
}
/// Rewrites `STORE(GEP(ptr, idxs), v)` to `STORE(ptr, GEP(v, inv))`.
///
/// `inv` lists the positions of `idxs` ordered by the lane they write to, i.e. the
/// inverse permutation; ties keep their original order.
fn move_gep_on_store(
    gep_inner: &Arc<UOp>,
    gep_indices: &[usize],
    value: &Arc<UOp>,
    ranges: &SmallVec<[Arc<UOp>; 4]>,
) -> Option<Arc<UOp>> {
    let mut inverse: Vec<usize> = (0..gep_indices.len()).collect();
    // Stable sort, so duplicate targets keep their original position order.
    inverse.sort_by_key(|&pos| gep_indices[pos]);
    Some(gep_inner.store_with_ranges(value.gep(inverse), ranges.clone()))
}
/// Expands a vector INDEX into a `VECTORIZE` of scalar-pointer INDEXes, one per lane.
///
/// If the buffer is a `VECTORIZE`, its first element is used for every lane. A
/// vector gate contributes its matching lane; a scalar gate is reused for every lane.
///
/// # Panics
/// If the INDEX has more than one index operand.
fn expand_index_to_vectorize(index: &Arc<UOp>) -> Option<Arc<UOp>> {
let Op::Index { buffer, indices, gate } = index.op() else { return None };
assert!(indices.len() <= 1, "ICE: expand_index_to_vectorize called with multi-index INDEX (len={})", indices.len());
let vec = indices.first()?;
let count = vec.dtype().vcount();
let buf = if let Op::Vectorize { elements } = buffer.op() { elements.first()?.clone() } else { buffer.clone() };
let scalar_indices: Vec<_> = (0..count)
.map(|i| {
let lane_gate = gate.as_ref().map(|g| if g.dtype().vcount() > 1 { g.gep(vec![i]) } else { g.clone() });
UOp::index()
.buffer(buf.clone())
.indices(vec![vec.gep(vec![i])])
.maybe_gate(lane_gate)
.ptr(true)
.call()
.expect("ICE: unable to create index")
})
.collect();
Some(UOp::vectorize(scalar_indices.into()))
}
/// Folds a `VECTORIZE` of scalar INDEXes over one buffer into `GEP(PTRCAT(...))`.
///
/// Each lane's index is decomposed into `root + constant offset`. Lanes that share
/// the same validity, root and gate are grouped. Runs of consecutive offsets become
/// one vector-pointer cast of the first lane's INDEX; singletons stay scalar.
/// The returned GEP maps every original lane to its position in the PTRCAT.
///
/// Groups are emitted in first-seen lane order, so the PTRCAT layout is
/// deterministic. Iterating the grouping `HashMap` directly would vary from run
/// to run because of its randomized hasher.
///
/// Returns `None` for an empty `VECTORIZE`, mixed buffers, non-INDEX elements, or a
/// buffer that is not a pointer.
fn fold_expanded_index(midx: &Arc<UOp>) -> Option<Arc<UOp>> {
    let Op::Vectorize { elements: sources } = midx.op() else { return None };
    let count = sources.len();
    if count == 0 {
        return None;
    }
    let first_buf = match sources[0].op() {
        Op::Index { buffer, .. } => buffer,
        _ => return None,
    };
    if !sources.iter().all(|s| matches!(s.op(), Op::Index { buffer, .. } if Arc::ptr_eq(buffer, first_buf))) {
        return None;
    }
    let buf = first_buf;

    // Per-lane decomposition of the index expression into `root + offset`.
    struct LaneData {
        valid: Arc<UOp>,
        root: Arc<UOp>,
        offset: i64,
        gate_id: u64,
    }
    let mut lane_data: Vec<(usize, LaneData)> = Vec::with_capacity(count);
    for (lane, idx_op) in sources.iter().enumerate() {
        let Op::Index { indices: simp_indices, gate: lane_gate, .. } = idx_op.op() else { continue };
        let idx = simp_indices.first()?.get_idx();
        let valid = simp_indices.first()?.get_valid();
        // Ungated lanes share a sentinel id so they group together.
        let gate_id = lane_gate.as_ref().map_or(u64::MAX, |g| g.id);
        let (root, offset) = match idx.op() {
            Op::Invalid => (UOp::invalid_marker(), 0),
            Op::Binary(BinaryOp::Add, l, r) if matches!(r.op(), Op::Const(cv) if matches!(cv.0, ConstValue::Int(_))) => {
                let Op::Const(cv) = r.op() else { unreachable!() };
                let ConstValue::Int(off) = cv.0 else { unreachable!() };
                (l.clone(), off)
            }
            Op::Binary(BinaryOp::Add, l, r) if matches!(l.op(), Op::Const(cv) if matches!(cv.0, ConstValue::Int(_))) => {
                let Op::Const(cv) = l.op() else { unreachable!() };
                let ConstValue::Int(off) = cv.0 else { unreachable!() };
                (r.clone(), off)
            }
            Op::Const(cv) if matches!(cv.0, ConstValue::Int(_)) => {
                let ConstValue::Int(off) = cv.0 else { unreachable!() };
                (UOp::index_const(0), off)
            }
            _ => (idx.clone(), 0),
        };
        lane_data.push((lane, LaneData { valid, root, offset, gate_id }));
    }

    // (valid id, root id, gate id) -> offset -> lanes using that offset.
    // `root_order` remembers first-seen key order for deterministic output.
    type RootKey = (u64, u64, u64);
    let mut root_order: Vec<RootKey> = Vec::new();
    let mut offsets_by_root: HashMap<RootKey, HashMap<i64, Vec<usize>>> = HashMap::new();
    for (lane, data) in &lane_data {
        let key: RootKey = (data.valid.id, data.root.id, data.gate_id);
        offsets_by_root
            .entry(key)
            .or_insert_with(|| {
                root_order.push(key);
                HashMap::new()
            })
            .entry(data.offset)
            .or_default()
            .push(*lane);
    }

    let mut ret = Vec::new();
    let mut idxs: Vec<Option<usize>> = vec![None; count];
    let mut global_offset = 0;
    for key in &root_order {
        let offsets = &offsets_by_root[key];
        let groups = group_consecutive_offsets_from_map(offsets);
        for grp in groups {
            // The first lane at the lowest offset of the run provides the base pointer.
            let lidx = sources[offsets[&grp[0]][0]].clone();
            let ptr = if grp.len() > 1 { lidx.cast(make_vec_ptr_dtype(buf, grp.len())) } else { lidx };
            for (i, &offset) in grp.iter().enumerate() {
                for &lane in &offsets[&offset] {
                    idxs[lane] = Some(global_offset + i);
                }
            }
            ret.push(ptr);
            global_offset += grp.len();
        }
    }
    if idxs.iter().any(|x| x.is_none()) {
        return None;
    }
    let DType::Ptr { base, addrspace, size, .. } = buf.dtype().clone() else { return None };
    let scalar_ptr = DType::Ptr { base: Box::new(DType::Scalar(base.scalar()?)), addrspace, size, vcount: 1 };
    let ptrcat_dtype = scalar_ptr.vec(global_offset);
    let ptrcat = UOp::ptrcat().sources(ret).dtype(ptrcat_dtype).call();
    let gep_indices: Vec<usize> = idxs.into_iter().map(|x| x.unwrap()).collect();
    Some(ptrcat.gep(gep_indices))
}
/// Partitions the keys of `offsets_map` into runs of consecutive integers.
///
/// Keys are sorted ascending. Each returned run is ascending, and runs are ordered
/// by their first key. An empty map yields no runs.
fn group_consecutive_offsets_from_map(offsets_map: &HashMap<i64, Vec<usize>>) -> Vec<Vec<i64>> {
    let mut keys: Vec<i64> = offsets_map.keys().copied().collect();
    keys.sort_unstable();
    let mut runs: Vec<Vec<i64>> = Vec::new();
    for key in keys {
        match runs.last_mut() {
            // `prev < key` (keys are unique and sorted), so `prev + 1` cannot overflow.
            Some(run) if run.last().is_some_and(|&prev| prev + 1 == key) => run.push(key),
            _ => runs.push(vec![key]),
        }
    }
    runs
}
/// Builds a pointer-to-vector dtype of `vec_len` lanes over `buffer`'s element type.
///
/// The address space is taken from `buffer`'s pointer dtype. For non-pointer buffers
/// the element is the dtype's base and the address space defaults to `Global`.
fn make_vec_ptr_dtype(buffer: &Arc<UOp>, vec_len: usize) -> DType {
    let (scalar, addrspace) = match buffer.ptrdtype() {
        Some((base, addrspace, _)) => (base.base(), addrspace),
        None => (buffer.dtype().base(), AddrSpace::Global),
    };
    DType::Ptr {
        base: Box::new(DType::Vector { scalar, count: vec_len }),
        addrspace,
        size: Some(vec_len),
        vcount: 1,
    }
}
/// Splits `LOAD(buf, PTRCAT(p0, p1, ...))` into `CAT(LOAD(p0), LOAD(p1), ...)`.
///
/// Each sub-load reads its pointer's pointee type. If the buffer is a `VECTORIZE`,
/// pointer `i` uses element `i`, falling back to the whole buffer. The original
/// load's `alt`, if any, is passed to every sub-load.
fn distribute_ptrcat_load(
    load: &Arc<UOp>,
    buffer: &Arc<UOp>,
    ptrcat: &Arc<UOp>,
    sources: &[Arc<UOp>],
) -> Option<Arc<UOp>> {
    // Loop-invariant: the alternative value of the original load.
    let shared_alt = if let Op::Load { alt, .. } = load.op() { alt.clone() } else { None };
    let mut loads: Vec<Arc<UOp>> = Vec::with_capacity(sources.len());
    for (i, ptr) in sources.iter().enumerate() {
        let pointee = match ptr.dtype() {
            DType::Ptr { base, .. } => base.as_ref().clone(),
            other => other.clone(),
        };
        let lane_buffer = match buffer.op() {
            Op::Vectorize { elements, .. } => elements.get(i).unwrap_or(buffer).clone(),
            _ => buffer.clone(),
        };
        loads.push(UOp::load().buffer(lane_buffer).index(ptr.clone()).maybe_alt(shared_alt.clone()).dtype(pointee).call());
    }
    let ptrcat_dtype = ptrcat.dtype();
    let cat_dtype = DType::Scalar(ptrcat_dtype.base()).vec(ptrcat_dtype.vcount());
    Some(UOp::cat().sources(loads).dtype(cat_dtype).call())
}
/// Rewrite a `Store` through a `PTRCAT` into a `GROUP` of one store per concatenated pointer.
///
/// Lanes of `value` are handed out in order: each pointer in `sources` takes the next
/// `ptr_element_count(ptr)` lanes, extracted with a `gep`. Each partial store carries the
/// original `ranges`. In debug builds, running past `value`'s lane count triggers a
/// "PTRCAT size mismatch" assertion. Always returns `Some`.
fn distribute_ptrcat_store(
    sources: &[Arc<UOp>],
    value: &Arc<UOp>,
    ranges: &SmallVec<[Arc<UOp>; 4]>,
) -> Option<Arc<UOp>> {
    let total_lanes = value.dtype().vcount();
    let mut next_lane = 0usize;
    let stores: Vec<Arc<UOp>> = sources
        .iter()
        .map(|ptr| {
            let width = ptr_element_count(ptr);
            debug_assert!(next_lane + width <= total_lanes, "PTRCAT size mismatch");
            let lanes: Vec<usize> = (next_lane..next_lane + width).collect();
            next_lane += width;
            ptr.store_with_ranges(value.gep(lanes), ranges.clone())
        })
        .collect();
    Some(UOp::group(stores))
}
/// Number of elements addressed through `ptr`.
///
/// For a pointer dtype this is the pointee's lane count (`vcount`); any other dtype counts
/// as a single element.
fn ptr_element_count(ptr: &Arc<UOp>) -> usize {
    if let DType::Ptr { base, .. } = ptr.dtype() { base.vcount() } else { 1 }
}
/// Split a multi-element LOAD/STORE into chunks that can be emitted as vector accesses.
///
/// `ls` is a `Load` or `Store` whose `index` is a pointer covering `sz > 1` elements.
/// `idx` is the `Index` node it goes through.
///
/// Chunk widths are chosen greedily, largest first, from one candidate list:
/// * `[1]`: no folding. Used for buffers that are neither float nor image, and for
///   register (`AddrSpace::Reg`) buffers that are not AMX tensor-core accumulators.
/// * `[4, 1]`: image buffers.
/// * `[16, 8, 4, 2, 1]`: when `MOROK_AMX=1`.
/// * `[4, 2, 1]`: otherwise.
///
/// Unless the access is an AMX tensor-core accumulator, any width that the first index
/// expression is not provably a multiple of is discarded. A chunk wider than one element is
/// addressed by casting the (offset) index to a vector pointer. Partial loads keep the
/// original load's `alt` fallback.
///
/// Returns a `CAT` of the partial loads or a `GROUP` of the partial stores. Returns `None`
/// when there is a single element or the access would not be split into several chunks.
fn split_load_store(ls: &Arc<UOp>, idx: &Arc<UOp>) -> Option<Arc<UOp>> {
    let Op::Index { buffer: buf, indices, .. } = idx.op() else { return None };
    let sz = match ls.op() {
        Op::Load { index, .. } | Op::Store { index, .. } => ptr_element_count(index),
        _ => return None,
    };
    if sz == 1 {
        return None;
    }
    let buf_dtype = buf.dtype();

    // Read once per process: AMX chunk widths are enabled by `MOROK_AMX=1`.
    static IS_AMX: std::sync::LazyLock<bool> =
        std::sync::LazyLock::new(|| std::env::var("MOROK_AMX").is_ok_and(|v| v == "1"));
    let is_amx = *IS_AMX;

    /// At least 16 float lanes, held in a register pointer or a plain vector.
    fn is_amx_tc_reg_ptr(dtype: &DType, sz: usize) -> bool {
        sz >= 16
            && dtype.base().is_float()
            && matches!(dtype, DType::Ptr { addrspace: AddrSpace::Reg, .. } | DType::Vector { .. })
    }

    /// Follow a chain of `Gep`s down to the `Load` it extracts from, if any.
    fn find_underlying_load(uop: &Arc<UOp>) -> Option<Arc<UOp>> {
        match uop.op() {
            Op::Gep { vector, .. } => find_underlying_load(vector),
            Op::Load { .. } => Some(uop.clone()),
            _ => None,
        }
    }

    // AMX tensor-core accumulator access has two forms: a load from an AMX-shaped register
    // buffer, or a store whose value is (a GEP of) a load from one.
    let is_amx_tc_acc = is_amx
        && match ls.op() {
            Op::Store { value, .. } => find_underlying_load(value).is_some_and(|load| match load.op() {
                Op::Load { index, .. } => {
                    matches!(index.op(), Op::Index { buffer, .. } if is_amx_tc_reg_ptr(&buffer.dtype(), sz))
                }
                _ => false,
            }),
            Op::Load { .. } => is_amx_tc_reg_ptr(&buf_dtype, sz),
            _ => false,
        };

    let is_image = matches!(buf_dtype, DType::Image { .. });
    let no_fold = (!buf_dtype.base().is_float() && !is_image)
        || (matches!(buf_dtype, DType::Ptr { addrspace: AddrSpace::Reg, .. }) && !is_amx_tc_acc);
    let mut lengths: Vec<usize> = if no_fold {
        vec![1]
    } else if is_image {
        vec![4, 1]
    } else if is_amx {
        vec![16, 8, 4, 2, 1]
    } else {
        vec![4, 2, 1]
    };
    // Keep only widths the starting offset is provably a multiple of. Width 1 always
    // passes, so at least one candidate survives.
    if !is_amx_tc_acc && let Some(offset) = indices.first() {
        lengths.retain(|&len| offset_divides_evenly(offset, len));
    }

    // The original load's fallback value must survive the split, as it does in
    // `distribute_ptrcat_load`. If the fallback is presumably a full `sz`-lane vector, each
    // partial load receives only its own lanes; otherwise it is passed through unchanged.
    let load_alt = match ls.op() {
        Op::Load { alt, .. } => alt.clone(),
        _ => None,
    };

    let scalar_dtype = buf_dtype.scalar_dtype();
    let mut ret = Vec::new();
    let mut pos = 0usize;
    while pos < sz {
        // Width 1 is always a candidate, so this lookup cannot fail. Bail out rather than
        // spin forever if that invariant is ever broken.
        let Some(fold_len) = lengths.iter().copied().find(|&len| pos + len <= sz) else { return None };
        let lidx = if pos == 0 { idx.clone() } else { offset_index(idx, pos as i64) };
        let lidx = if fold_len > 1 { lidx.cast(make_vec_ptr_dtype(buf, fold_len)) } else { lidx };
        let lanes: Vec<usize> = (pos..pos + fold_len).collect();
        let part = match ls.op() {
            Op::Store { value, ranges, .. } => lidx.store_with_ranges(value.gep(lanes), ranges.clone()),
            Op::Load { buffer, .. } => {
                let alt = load_alt
                    .as_ref()
                    .map(|a| if a.dtype().vcount() == sz { a.gep(lanes) } else { a.clone() });
                UOp::load()
                    .buffer(buffer.clone())
                    .index(lidx)
                    .maybe_alt(alt)
                    .dtype(scalar_dtype.vec(fold_len))
                    .call()
            }
            _ => return None,
        };
        ret.push(part);
        pos += fold_len;
    }

    if ret.len() <= 1 {
        return None;
    }
    match ls.op() {
        Op::Load { .. } => Some(UOp::cat().sources(ret).dtype(scalar_dtype.vec(sz)).call()),
        Op::Store { .. } => Some(UOp::group(ret)),
        _ => None,
    }
}
fn offset_divides_evenly(offset: &Arc<UOp>, len: usize) -> bool {
if len == 0 {
return false;
}
if len == 1 {
return true;
}
let v = len as i64;
match offset.op() {
Op::Const(cv) => matches!(cv.0, ConstValue::Int(n) if n % v == 0),
Op::VConst { values } => values.iter().all(|val| matches!(val, ConstValue::Int(n) if n % v == 0)),
Op::Binary(BinaryOp::Add, left, right) => offset_divides_evenly(left, len) && offset_divides_evenly(right, len),
Op::Binary(BinaryOp::Mul, left, right) => {
let check_const =
|c: &Arc<UOp>| matches!(c.op(), Op::Const(cv) if matches!(cv.0, ConstValue::Int(n) if n % v == 0));
check_const(left)
|| check_const(right)
|| offset_divides_evenly(left, len)
|| offset_divides_evenly(right, len)
}
_ => false,
}
}
/// Return a copy of the `Index` node `idx` with `offset` added to its first index expression.
///
/// The buffer, the gate and any further index expressions are kept as-is, and the rebuilt
/// node is always requested as a pointer index (`ptr(true)`). A node that is not an
/// `Index` is returned unchanged.
///
/// # Panics
/// Panics if the rebuilt index node cannot be constructed.
fn offset_index(idx: &Arc<UOp>, offset: i64) -> Arc<UOp> {
    let Op::Index { buffer, indices, gate } = idx.op() else { return idx.clone() };
    let mut exprs = indices.iter();
    let shifted: SmallVec<[Arc<UOp>; 4]> = exprs
        .next()
        .map(|head| head.add(&head.const_like(offset)))
        .into_iter()
        .chain(exprs.cloned())
        .collect();
    UOp::index()
        .buffer(buffer.clone())
        .indices(shifted)
        .maybe_gate(gate.clone())
        .ptr(true)
        .call()
        .expect("ICE: unable to create index")
}
/// Rewrite an image-buffer LOAD/STORE so it addresses the image by 2-D texel coordinates.
///
/// Two shapes are handled:
///
/// 1. Folded access, where `index = CAST(INDEX(image, lin))` with a 4-lane cast. The inner
///    index is rebuilt over texel coordinates and becomes `ls`'s source.
/// 2. Unfolded access, where `index = INDEX(image, lin)` is still linear. Only loads are
///    supported: the texel containing element `x` is loaded as a 4-vector, and lane
///    `x % 4` is selected by a chain of WHEREs whose innermost default is NaN. Stores are
///    rejected with a warning.
///
/// Returns `None` in these cases:
/// * `ls` is not a LOAD/STORE;
/// * it does not target an image;
/// * it is already in texel form;
/// * the rewrite cannot be built.
fn image_fixup(ls: &Arc<UOp>) -> Option<Arc<UOp>> {
    let (index, is_load) = match ls.op() {
        Op::Load { index, .. } => (index, true),
        Op::Store { index, .. } => (index, false),
        _ => return None,
    };

    // Case 1: CAST(INDEX(image, ..)) to a 4-wide pointer.
    if let Op::Cast { src: inner_idx, dtype: cast_dtype } = index.op()
        && let Op::Index { buffer: img_buf, .. } = inner_idx.op()
    {
        if !matches!(img_buf.dtype(), DType::Image { .. }) || cast_dtype.vcount() != 4 {
            return None;
        }
        let (new_idx, _) = image_texel_index(inner_idx)?;
        // NOTE(review): confirm that `replace().src(..)` keeps a STORE's value source.
        return Some(ls.replace().src(vec![new_idx]).call());
    }

    // Case 2: a plain INDEX(image, ..) that is still addressed linearly.
    if let Op::Index { buffer: img_buf, indices, .. } = index.op() {
        if !matches!(img_buf.dtype(), DType::Image { .. }) {
            return None;
        }
        // A 2-lane index is already in (x, y) texel form.
        if indices.first()?.get_idx().dtype().vcount() == 2 {
            return None;
        }
        if !is_load {
            tracing::warn!("image_fixup: STORE with unfoldable image not supported");
            return None;
        }
        let (new_idx, x) = image_texel_index(index)?;
        let vec4_dtype = ls.dtype().vec(4);
        let vec_load = UOp::load().buffer(ls.load_buffer()?).index(new_idx).dtype(vec4_dtype).call();
        let lane = x.mod_(&UOp::index_const(4));
        let nan = ls.const_like(ConstValue::Float(f64::NAN));
        // where(lane != 0, where(lane != 1, ..., gep 1), gep 0), bottoming out at NaN.
        let selected = (0..4).rev().fold(nan, |acc, i| {
            let not_lane_i = lane.ne(&UOp::index_const(i));
            UOp::try_where(not_lane_i, acc, vec_load.gep(vec![i as usize])).expect("WHERE")
        });
        return Some(selected);
    }
    None
}

/// Rebuild `index`, an `INDEX` into an image buffer by a linear element offset, as an
/// `INDEX` by texel coordinates.
///
/// Let `x` be the linear offset (the index expression without its validity condition) and
/// `width` be `shape[1]` of the image (1 if the shape has no second entry). The coordinates
/// are `(x / 4 % width, x / (4 * width))`. A validity condition other than the constant
/// `true` is re-attached to the coordinate vector, and the gate is carried over. The new
/// node is a pointer index when `index` has a pointer dtype; otherwise it keeps `index`'s
/// dtype.
///
/// Returns `(new_index, x)`. Returns `None` if `index` is not an `INDEX` into an image, has
/// no index expression, or the new node cannot be built.
fn image_texel_index(index: &Arc<UOp>) -> Option<(Arc<UOp>, Arc<UOp>)> {
    let Op::Index { buffer: img_buf, indices, gate } = index.op() else { return None };
    let DType::Image { shape, .. } = img_buf.dtype() else { return None };
    let lin_idx = indices.first()?;
    let x = lin_idx.get_idx();
    let valid = lin_idx.get_valid();
    let width = shape.get(1).copied().unwrap_or(1) as i64;
    let four = UOp::index_const(4);
    let width_const = UOp::index_const(width);
    let stride = UOp::index_const(4 * width);
    let x_coord = x.idiv(&four).mod_(&width_const);
    let y_coord = x.idiv(&stride);
    let oidx = UOp::vectorize(smallvec::smallvec![x_coord, y_coord]);
    let always_valid = matches!(valid.op(), Op::Const(cv) if cv.0 == ConstValue::Bool(true));
    let new_idx_expr = if always_valid { oidx } else { oidx.valid(valid) };
    let new_idx = if matches!(index.dtype(), DType::Ptr { .. }) {
        UOp::index()
            .buffer(img_buf.clone())
            .indices(vec![new_idx_expr])
            .maybe_gate(gate.clone())
            .ptr(true)
            .call()
            .ok()?
    } else {
        UOp::index()
            .buffer(img_buf.clone())
            .indices(vec![new_idx_expr])
            .maybe_gate(gate.clone())
            .dtype(index.dtype())
            .call()
            .ok()?
    };
    Some((new_idx, x))
}
use crate::symbolic::dce::reduce_identity;
/// Pattern matcher for the REDUCE-lowering pass, threaded through a [`ReduceContext`].
///
/// * Every `Reduce` node is lowered to a register accumulator by `reduce_to_acc`, which
///   records the accumulator's `END` node in the context.
/// * On the `Sink`, `ReduceContext::merge_reduce_ends` merges the recorded `END`s that close
///   the same set of ranges into one `END` over a `GROUP`, then clears the context. It
///   returns `None` (no rewrite) when no two recorded `END`s share ranges.
pub fn pm_reduce() -> TypedPatternMatcher<ReduceContext> {
    crate::patterns! {
        @context ReduceContext;
        red @ Reduce(_, ..) => {
            reduce_to_acc(red, ctx)
        },
        Sink { sources: _sources } => {
            ctx.merge_reduce_ends(_sources)
        },
    }
}
/// Split `inp` into the pieces that must be combined to reach the width of `out_dtype`.
///
/// If `inp` already has dtype `out_dtype`, it is returned on its own. Otherwise its lane
/// count must be a multiple of `out_dtype`'s. With `k = inp_lanes / out_lanes`, piece `i`
/// gathers lanes `i, i + k, i + 2k, …` via `gep`, producing `k` pieces.
///
/// # Panics
/// Panics if `inp`'s lane count is not a multiple of `out_dtype`'s.
fn horizontal_reduce(inp: &Arc<UOp>, out_dtype: &DType) -> Vec<Arc<UOp>> {
    let in_dtype = inp.dtype();
    if in_dtype == *out_dtype {
        return vec![inp.clone()];
    }
    let (inp_vcount, out_vcount) = (in_dtype.vcount(), out_dtype.vcount());
    assert!(
        inp_vcount.is_multiple_of(out_vcount),
        "horizontal mismatch: inp.dtype={:?} (vcount={}), out_dtype={:?} (vcount={})",
        in_dtype,
        inp_vcount,
        out_dtype,
        out_vcount
    );
    let stride = inp_vcount / out_vcount;
    let mut pieces = Vec::with_capacity(stride);
    for start in 0..stride {
        let lanes: Vec<usize> = (start..inp_vcount).step_by(stride).collect();
        pieces.push(inp.gep(lanes));
    }
    pieces
}
/// Lower one `Reduce` node to an explicit register accumulator.
///
/// The input is first split horizontally to the output width. With no reduce ranges, the
/// pieces are simply folded together with the reduce op. Otherwise:
/// 1. A 1-element register is stored with the reduction identity after the input ranges
///    that are still open (not reduced here and not closed by an `END` inside the input).
/// 2. Inside the reduce ranges it is read, combined with the pieces and stored back.
/// 3. That store is closed by an `END` over the reduce ranges, which is registered in
///    `ctx` so sibling `END`s over the same ranges can be merged later.
/// 4. The result is a read of the register after that `END`.
///
/// NOTE(review): the accumulator is created without any per-reduction identifier. If
/// `define_reg_typed` does not make each register unique, two reductions of the same dtype
/// could share storage; confirm.
fn reduce_to_acc(red: &Arc<UOp>, ctx: &mut ReduceContext) -> Option<Arc<UOp>> {
    let Op::Reduce { src: inp, ranges: reduce_range, reduce_op } = red.op() else { return None };
    let op = *reduce_op;
    let out_dtype = red.dtype();
    let combine = |lhs: Arc<UOp>, rhs: Arc<UOp>| apply_reduce_binary(op, lhs, rhs, &out_dtype);

    let pieces = horizontal_reduce(inp, &out_dtype);
    debug_assert!(pieces.iter().all(|p| p.dtype() == out_dtype), "horizontal reduction mismatch");
    if reduce_range.is_empty() {
        return pieces.into_iter().reduce(combine);
    }

    let topo = inp.toposort();
    let mut ended: HashSet<u64> = HashSet::new();
    for node in topo.iter() {
        if let Op::End { ranges, .. } = node.op() {
            ended.extend(ranges.iter().map(|r| r.id));
        }
    }
    let reduce_ids: HashSet<u64> = reduce_range.iter().map(|r| r.id).collect();
    let input_ranges: SmallVec<[Arc<UOp>; 4]> = topo
        .iter()
        .filter(|node| matches!(node.op(), Op::Range { .. }))
        .filter(|node| !reduce_ids.contains(&node.id) && !ended.contains(&node.id))
        .cloned()
        .collect();

    // Node construction below keeps the original order: identity, register, index constant,
    // init store, loop read, update, END, final read.
    let identity = reduce_identity(op, red.dtype());
    let acc = UOp::define_reg_typed(1, red.dtype());
    let zero = UOp::index_const(0);
    let at_zero = |base: Arc<UOp>| UOp::index().buffer(base).indices(vec![zero.clone()]).call().expect("index");

    let init = at_zero(acc.after(input_ranges)).store_value(identity);
    let loop_deps: SmallVec<[Arc<UOp>; 4]> = std::iter::once(init).chain(reduce_range.iter().cloned()).collect();
    let running = at_zero(acc.after(loop_deps));
    let updated = std::iter::once(running).chain(pieces).reduce(combine)?;
    let end = at_zero(acc.clone()).store_value(updated).end(reduce_range.clone());
    ctx.register_end(&end);
    Some(at_zero(acc.after(smallvec::smallvec![end])))
}
/// Combine two partial results `a` and `b` with the binary form of `reduce_op`, typed as
/// `dtype`.
///
/// `Add`, `Mul` and `Max` map directly to the matching binary op. `Min` is expressed as a
/// select: `where(a < b, a, b)`, with a boolean condition as wide as `dtype`. In debug
/// builds, the operands must share a dtype.
fn apply_reduce_binary(reduce_op: ReduceOp, a: Arc<UOp>, b: Arc<UOp>, dtype: &DType) -> Arc<UOp> {
    debug_assert!(a.dtype() == b.dtype(), "reduce operand dtype mismatch");
    let binary = match reduce_op {
        ReduceOp::Add => BinaryOp::Add,
        ReduceOp::Mul => BinaryOp::Mul,
        ReduceOp::Max => BinaryOp::Max,
        ReduceOp::Min => {
            let a_lt_b = UOp::new(Op::Binary(BinaryOp::Lt, a.clone(), b.clone()), DType::Bool.vec(dtype.vcount()));
            return UOp::try_where(a_lt_b, a, b).expect("WHERE");
        }
    };
    UOp::new(Op::Binary(binary, a, b), dtype.clone())
}