use std::collections::HashMap;
use std::sync::Arc;
use morok_dtype::DType;
use morok_ir::prelude::*;
use morok_ir::{AxisId, AxisType};
use crate::TypedPatternMatcher;
use smallvec::{SmallVec, smallvec};
fn const_to_usize(cv: &ConstValue) -> Option<usize> {
match cv {
ConstValue::Int(i) if *i > 0 => Some(*i as usize),
ConstValue::UInt(u) => Some(*u as usize),
_ => None,
}
}
/// Detects a `Vectorize` whose lanes are all the very same node.
///
/// Returns the shared element together with the lane count when `uop` is a
/// non-empty `Vectorize` and every element is pointer-identical (`Arc::ptr_eq`)
/// to the first one; returns `None` otherwise.
fn broadcast_info(uop: &Arc<UOp>) -> Option<(Arc<UOp>, usize)> {
    let Op::Vectorize { elements } = uop.op() else {
        return None;
    };
    let head = elements.first()?;
    for lane in elements.iter().skip(1) {
        if !Arc::ptr_eq(lane, head) {
            return None;
        }
    }
    Some((head.clone(), elements.len()))
}
/// Wraps the operands of a `Bufferize` in `Contract` nodes when both its
/// compute value and its single range are `Unroll`s.
///
/// Both contracts use the unroll axes of the range, and each is typed as the
/// wrapped node's dtype widened by the lane count of the range's inner value.
/// The rebuilt `Bufferize` keeps the original options and dtype.
///
/// Returns `None` unless `bufferize` is a `Bufferize` over an `Unroll` compute
/// with exactly one range that is itself an `Unroll`.
fn fix_bufferize_unroll(bufferize: &Arc<UOp>) -> Option<Arc<UOp>> {
    let Op::Bufferize { compute, ranges, opts } = bufferize.op() else {
        return None;
    };
    if !matches!(compute.op(), Op::Unroll { .. }) || ranges.len() != 1 {
        return None;
    }
    let range = &ranges[0];
    let Op::Unroll { src: range_inner, unroll_axes } = range.op() else {
        return None;
    };

    let lanes = range_inner.dtype().vcount();
    // Same wrapping recipe for the compute value and for the range.
    let contract = |node: &Arc<UOp>| {
        UOp::new(Op::Contract { src: node.clone(), upcast_ranges: unroll_axes.clone() }, node.dtype().vec(lanes))
    };
    let new_compute = contract(compute);
    let new_range = contract(range);

    Some(UOp::new(
        Op::Bufferize { compute: new_compute, ranges: smallvec![new_range], opts: opts.clone() },
        bufferize.dtype(),
    ))
}
/// Flattens a per-axis choice into a linear index.
///
/// `args` lists `(axis, size)` pairs with the last pair varying fastest
/// (mixed-radix / row-major order). `rpk` maps an axis to the chosen position
/// on that axis; an axis missing from `rpk` contributes position `0`.
pub fn expand_arg_to_idx(args: &[(usize, usize)], rpk: &HashMap<usize, usize>) -> usize {
    let (flat, _) = args.iter().rev().fold((0usize, 1usize), |(acc, stride), &(axis, size)| {
        let position = rpk.get(&axis).copied().unwrap_or(0);
        (acc + position * stride, stride * size)
    });
    flat
}
/// Enumerates every assignment of positions to the given axes.
///
/// For `args = [(axis, size), ...]` the result holds one map per element of
/// the Cartesian product of `0..size` over all axes. Maps are ordered with the
/// first axis varying slowest and the last axis varying fastest. An empty
/// `args` yields a single empty map; any axis of size `0` yields no maps.
pub fn choices_from_args(args: &[(usize, usize)]) -> Vec<HashMap<usize, usize>> {
    let mut combos = vec![HashMap::new()];
    for &(axis, size) in args {
        let mut grown = Vec::new();
        for partial in combos {
            for position in 0..size {
                let mut choice = partial.clone();
                choice.insert(axis, position);
                grown.push(choice);
            }
        }
        combos = grown;
    }
    combos
}
/// Computes, for every choice over `cargs`, the linear index of that choice
/// inside the axis layout `eargs`.
///
/// Choices are enumerated by `choices_from_args(cargs)` (first axis slowest).
/// Before each lookup every axis in `exclude_args` is pinned to position `0`,
/// overriding any value chosen for it. Axes of `eargs` that are neither chosen
/// nor excluded also count as position `0` (see `expand_arg_to_idx`).
pub fn swizzle_args(cargs: &[(usize, usize)], eargs: &[(usize, usize)], exclude_args: &[usize]) -> Vec<usize> {
    choices_from_args(cargs)
        .into_iter()
        .map(|mut rpk| {
            // Each choice is already owned here, so pin the excluded axes in
            // place instead of cloning the map first.
            rpk.extend(exclude_args.iter().map(|&ax| (ax, 0)));
            expand_arg_to_idx(eargs, &rpk)
        })
        .collect()
}
/// Runs the two-phase pre-expansion rewrite over `ast` and returns the result.
///
/// Phase 1 applies `phase1_range_to_unroll` on its own. Phase 2 applies one
/// combined matcher built from `sym()`, `pm_pre_expander()`,
/// `pm_group_for_reduce()` and `expander()`; that combined matcher is
/// constructed lazily on first use and then reused for later calls.
pub fn pre_expand(ast: &Arc<UOp>) -> Arc<UOp> {
    use crate::rewrite::graph_rewrite;
    use crate::symbolic::patterns::sym;
    use std::sync::LazyLock;

    static PHASE2: LazyLock<TypedPatternMatcher> =
        LazyLock::new(|| sym().clone() + pm_pre_expander() + pm_group_for_reduce() + expander());

    let after_phase1 = graph_rewrite(phase1_range_to_unroll(), ast.clone(), &mut ());
    graph_rewrite(&*PHASE2, after_phase1, &mut ())
}
/// Pattern set used by the first phase of `pre_expand`.
///
/// Contains a single rule: a `Range` whose end is a constant and whose axis
/// type is `AxisType::Unroll` is replaced by the constant vector
/// `0, 1, .., size - 1` wrapped in an unroll over `(axis_id, size)`.
///
/// NOTE(review): the matcher is built through `crate::cached_patterns!` and
/// returned as `&'static`, so it is presumably constructed once and cached;
/// confirm against the macro definition.
fn phase1_range_to_unroll() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
range @ Range { end: _end @const(cv), axis_id, axis_type }
if matches!(axis_type, AxisType::Unroll) => |range| {
// The rule declines when the end constant is not a usable positive size.
let size = const_to_usize(&cv)?;
let values: Vec<ConstValue> = (0..size as i64).map(ConstValue::Int).collect();
// The lanes use the range's scalar dtype; the unroll keeps the range's dtype.
let vconst = UOp::vconst(values, range.dtype().scalar_dtype());
Some(vconst.unroll_with_dtype(vec![(axis_id.value(), size)], range.dtype()))
},
}
}
/// Pattern set that prepares the graph for expansion.
///
/// Rules, in order:
/// 1. A constant-ended `Range` with axis type `Upcast` or `Unroll` becomes the
///    constant vector `0, 1, .., size - 1` wrapped in an unroll over
///    `(axis_id, size)`.
/// 2. Every `Reduce` is offered to `fix_reduce_unroll`.
/// 3. Every `Store` is offered to `fix_store_unroll`.
pub fn pm_pre_expander() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
range @ Range { end: _end @const(cv), axis_id, axis_type }
if matches!(axis_type, AxisType::Upcast | AxisType::Unroll) => |range| {
// The rule declines when the end constant is not a usable positive size.
let size = const_to_usize(&cv)?;
let values: Vec<ConstValue> = (0..size as i64).map(ConstValue::Int).collect();
// The lanes use the range's scalar dtype; the unroll keeps the range's dtype.
let vconst = UOp::vconst(values, range.dtype().scalar_dtype());
Some(vconst.unroll_with_dtype(vec![(axis_id.value(), size)], range.dtype()))
},
// Non-`Range` entries among a reduce's ranges are turned into a contract of its source.
reduce @ Reduce(_, ..) => |reduce| fix_reduce_unroll(reduce),
// `Unroll` entries among a store's ranges are turned into a contract around the store.
store if matches!(store.op(), Op::Store { .. }) => |store| fix_store_unroll(store),
}
}
/// Main expansion pattern set: moves `Unroll` nodes outward through the graph
/// and resolves `Contract` nodes against them.
///
/// Each rule is described by the comment directly above it.
pub fn expander() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
// `After` over a broadcast (a `Vectorize` of one repeated node): apply the
// `After` to the single node, then repeat that result `count` times.
After { passthrough, deps, .. } if broadcast_info(passthrough).is_some() => |after| {
let (src, count) = broadcast_info(passthrough)?;
let elements: SmallVec<[Arc<UOp>; 4]> = std::iter::repeat_n(src.after(deps.clone()), count).collect();
Some(UOp::vectorize(elements))
},
// Same treatment for an `End` whose computation is a broadcast.
End { computation, ranges, .. } if broadcast_info(computation).is_some() => |end| {
let (src, count) = broadcast_info(computation)?;
let elements: SmallVec<[Arc<UOp>; 4]> = std::iter::repeat_n(src.end(ranges.clone()), count).collect();
Some(UOp::vectorize(elements))
},
// `End` with `Unroll` entries among its ranges: see `end_unrolls`.
end @ End(_, ..) => |end| end_unrolls(end),
// `Bufferize` over an unrolled compute and range: see `fix_bufferize_unroll`.
bufferize if matches!(bufferize.op(), Op::Bufferize { .. }) => |bufferize| {
fix_bufferize_unroll(bufferize)
},
// Nested unrolls collapse into one: inner axes first, then outer axes,
// keeping the outer node's dtype.
outer @ Unroll { src: Unroll { src: inner_src, unroll_axes: inner_axes, .. }, unroll_axes: outer_axes, .. } => |outer| {
let combined: Vec<(usize, usize)> = inner_axes.iter().chain(outer_axes.iter()).cloned().collect();
Some(inner_src.unroll_with_dtype(combined, outer.dtype()))
},
// Any expandable op with at least one `Unroll` source: see `do_expand`.
op if op.op().is_expandable() && has_unroll_input(op) => |op| do_expand(op),
// `Contract` nodes: see `do_contract`.
contract @ Contract(_, ..) => |contract| do_contract(contract),
// A barrier over an `Unroll` is moved inside it: the new barrier wraps the
// unrolled value directly (typed like that value) and is re-wrapped in an
// unroll over the same axes.
Barrier { src: Unroll { src: inner, unroll_axes, .. }, deps, .. } => |barrier| {
let inner_barrier = UOp::new(Op::Barrier { src: inner.clone(), deps: deps.clone() }, inner.dtype());
Some(inner_barrier.unroll(unroll_axes.clone()))
},
// `Unroll` of a `Vectorize`: try to fuse lane-wise GEP/ALU chains.
unroll @ Unroll { src, .. } if matches!(src.op(), Op::Vectorize { .. }) => |unroll| fuse_unroll_gep_alu(unroll),
// An `Unroll` with no axes is replaced by its source.
// NOTE(review): `~>` is presumably the macro's infallible-rewrite arrow; confirm against the macro definition.
Unroll { src, unroll_axes, .. } if unroll_axes.is_empty() ~> src,
}
}
/// Fuses an unrolled vector of lane-wise binary ops into one vector binary op.
///
/// Matches `Unroll(Vectorize(e_0, .., e_{n-1}))` where every `e_i` is
/// `Binary(op, Gep(x, [i]), Gep(y, [i]))` with the same `op` and the same
/// (pointer-identical) vectors `x` and `y`. The result is
/// `Unroll(Vectorize(Gep(f, [0]), .., Gep(f, [n-1])))` with `f = Binary(op, x, y)`,
/// keeping the original unroll axes and dtype.
///
/// Returns `None` when the shape does not match, when the `Vectorize` is
/// empty, or when `x` and `y` have different lane counts.
fn fuse_unroll_gep_alu(unroll: &Arc<UOp>) -> Option<Arc<UOp>> {
    let Op::Unroll { src, unroll_axes } = unroll.op() else { return None };
    let Op::Vectorize { elements } = src.op() else { return None };

    // The first lane fixes the operator and the two base vectors.
    let first = elements.first()?;
    let Op::Binary(bin_op, first_a, first_b) = first.op() else { return None };
    let Op::Gep { vector: base_x, .. } = first_a.op() else { return None };
    let Op::Gep { vector: base_y, .. } = first_b.op() else { return None };

    // A single vector op over operands of different widths would be ill-typed,
    // so leave such graphs untouched.
    let lanes = base_x.dtype().vcount();
    if base_y.dtype().vcount() != lanes {
        return None;
    }

    // Lane `i` must be `op(x[i], y[i])` over the same two base vectors.
    for (i, elem) in elements.iter().enumerate() {
        let Op::Binary(op, a, b) = elem.op() else { return None };
        if op != bin_op {
            return None;
        }
        let Op::Gep { vector: x, indices: a_idx } = a.op() else { return None };
        let Op::Gep { vector: y, indices: b_idx } = b.op() else { return None };
        if a_idx.len() != 1 || b_idx.len() != 1 || a_idx[0] != i || b_idx[0] != i {
            return None;
        }
        if !Arc::ptr_eq(x, base_x) || !Arc::ptr_eq(y, base_y) {
            return None;
        }
    }

    // Type the fused op from the per-lane result dtype rather than from the
    // left operand: ops whose result dtype differs from their operands (for
    // example comparisons) would otherwise be mistyped.
    let fused_dtype = first.dtype().vec(lanes);
    let fused = UOp::new(Op::Binary(*bin_op, base_x.clone(), base_y.clone()), fused_dtype);
    let new_elements: SmallVec<[Arc<UOp>; 4]> = (0..elements.len()).map(|i| fused.gep(vec![i])).collect();
    Some(UOp::vectorize(new_elements).unroll_with_dtype(unroll_axes.clone(), unroll.dtype()))
}
/// Reports whether at least one direct source of `uop` is an `Unroll`.
fn has_unroll_input(uop: &Arc<UOp>) -> bool {
    for input in uop.op().sources().iter() {
        if let Op::Unroll { .. } = input.op() {
            return true;
        }
    }
    false
}
/// Pushes an expandable op through its `Unroll` sources.
///
/// Every `Unroll` source is replaced by its inner value (re-indexed with a GEP
/// when its axes differ from the combined axes), the other sources are either
/// passed through or widened, and the rebuilt op is wrapped in one `Unroll`
/// over the combined axes.
///
/// Returns `None` when the op is not expandable, has no `Unroll` source, or
/// the combined expansion size is zero.
fn do_expand(uop: &Arc<UOp>) -> Option<Arc<UOp>> {
let op = uop.op();
if !op.is_expandable() {
return None;
}
let sources = op.sources();
let unroll_sources: Vec<(usize, &Arc<UOp>)> =
sources.iter().enumerate().filter(|(_, s)| matches!(s.op(), Op::Unroll { .. })).collect();
if unroll_sources.is_empty() {
return None;
}
// For WMMA, the axes named by the metadata's upcast and reduce axes are kept
// out of the expansion; `swizzle_args` pins them to position 0.
let exclude_args: Vec<usize> = if let Op::Wmma { metadata, .. } = op {
let mut ids = metadata.upcast_axes.all_axis_ids();
ids.extend(metadata.reduce_axes.iter());
ids.sort_unstable();
ids.dedup();
ids
} else {
vec![]
};
let all_expand_args: Vec<Vec<(usize, usize)>> = unroll_sources
.iter()
.filter_map(|(_, s)| if let Op::Unroll { unroll_axes, .. } = s.op() { Some(unroll_axes.clone()) } else { None })
.collect();
// Fast path: every unrolled source uses the same axes and nothing is
// excluded, so those axes are used as-is. Otherwise the axes of all sources
// are merged, ordered by axis id, de-duplicated and filtered.
let expand_args: Vec<(usize, usize)> =
if all_expand_args.iter().all(|a| a == &all_expand_args[0]) && exclude_args.is_empty() {
all_expand_args[0].clone()
} else {
let mut combined: Vec<(usize, usize)> = all_expand_args.into_iter().flatten().collect();
// NOTE(review): the sort is keyed on the axis id only and `dedup` removes
// only adjacent equal pairs, so the same axis appearing with two different
// sizes would survive as two entries; confirm that cannot happen upstream.
combined.sort_by_key(|(ax, _)| *ax);
combined.dedup();
combined.into_iter().filter(|(ax, _)| !exclude_args.contains(ax)).collect()
};
// Total number of expanded lanes; an axis of size 0 makes the rule decline.
let expand_sz: usize = expand_args.iter().map(|(_, sz)| sz).product();
if expand_sz == 0 {
return None;
}
// Sources at or after this index (when the op reports one) are passed through untouched.
let range_start_idx = op.range_ending_src_index();
let mut new_sources: SmallVec<[Arc<UOp>; 4]> = SmallVec::new();
for (i, src) in sources.iter().enumerate() {
if let Op::Unroll { src: inner, unroll_axes: src_axes } = src.op() {
if *src_axes == expand_args {
// Same axes as the combined expansion: just drop the `Unroll` wrapper.
new_sources.push(inner.clone());
} else {
// Different axes: pick, for every combined choice, the matching lane of
// this source.
let swizzle_indices = swizzle_args(&expand_args, src_axes, &exclude_args);
let wrapper_count = src.dtype().vcount();
// When the unroll itself is vector-typed, each picked index stands for
// `wrapper_count` consecutive lanes.
let final_indices: Vec<usize> = if wrapper_count > 1 {
swizzle_indices
.iter()
.flat_map(|&idx| (0..wrapper_count).map(move |j| idx * wrapper_count + j))
.collect()
} else {
swizzle_indices
};
new_sources.push(inner.gep(final_indices));
}
} else {
if let Some(range_idx) = range_start_idx
&& i >= range_idx
{
new_sources.push(src.clone());
continue;
}
// For an `Index` whose own dtype is not a pointer, every source after the
// first is passed through unchanged.
if i >= 1 && matches!(op, Op::Index { .. }) && !matches!(uop.dtype(), DType::Ptr { .. }) {
new_sources.push(src.clone());
continue;
}
let src_count = src.dtype().vcount();
if src_count > 1 {
// Vector-typed source: concatenate `expand_sz` copies of it.
let cat_sources: Vec<Arc<UOp>> = (0..expand_sz).map(|_| src.clone()).collect();
new_sources.push(UOp::cat().sources(cat_sources).call());
} else {
// Scalar source: broadcast it to `expand_sz` lanes.
new_sources.push(src.broadcast(expand_sz));
}
}
}
let base_dtype = uop.dtype();
let base_count = base_dtype.vcount();
// Widen the result dtype by the expansion size when it has a scalar base;
// otherwise keep the dtype unchanged.
let new_dtype = if let Some(scalar) = base_dtype.scalar() {
DType::Scalar(scalar).vec(base_count * expand_sz)
} else {
base_dtype.clone()
};
if let Op::Gep { indices, .. } = op {
// GEP is rebuilt by hand: each original index `idx` becomes the indices
// `idx + e * stride` for `e` in `0..expand_sz`, where `stride` is the
// expanded source's lane count divided by `expand_sz`.
debug_assert_eq!(base_dtype.vcount(), 1, "GEP expansion expects scalar output dtype");
let src = new_sources.first()?;
let src_count = src.dtype().vcount();
let stride = src_count / expand_sz;
let new_indices: Vec<usize> =
indices.iter().flat_map(|&idx| (0..expand_sz).map(move |e| idx + e * stride)).collect();
let gep_result = src.gep(new_indices);
return Some(gep_result.unroll(expand_args));
}
// General case: same op over the new sources with the widened dtype, wrapped
// in an `Unroll` that carries the original dtype.
let new_op = uop.replace().dtype(new_dtype.clone()).src(new_sources.to_vec()).call();
Some(new_op.unroll_with_dtype(expand_args, base_dtype))
}
pub(crate) fn fix_reduce_unroll(reduce: &Arc<UOp>) -> Option<Arc<UOp>> {
let Op::Reduce { src, reduce_op, ranges } = reduce.op() else {
return None;
};
let (reduce_range, reduce_expand): (Vec<_>, Vec<_>) =
ranges.iter().partition(|r| matches!(r.op(), Op::Range { .. }));
if reduce_expand.is_empty() {
return None;
}
let reduce_expand: Vec<_> = reduce_expand.into_iter().filter(|r| !matches!(r.op(), Op::Const(_))).collect();
if reduce_expand.is_empty() {
return None;
}
debug_assert!(
reduce_expand.iter().all(|r| matches!(r.op(), Op::Unroll { .. })),
"not all UNROLLS in {:?}",
reduce_expand.iter().map(|r| r.op().as_ref()).collect::<Vec<_>>()
);
let contract_axes: Vec<(usize, usize)> = reduce_expand
.iter()
.filter_map(|u| match u.op() {
Op::Unroll { unroll_axes, .. } => Some(unroll_axes.clone()),
_ => None,
})
.flatten()
.collect();
let contracted_src = if !contract_axes.is_empty() {
let total: usize = contract_axes.iter().map(|(_, sz)| sz).product();
UOp::new(Op::Contract { src: src.clone(), upcast_ranges: contract_axes }, reduce.dtype().vec(total))
} else {
src.clone()
};
Some(UOp::new(
Op::Reduce { src: contracted_src, ranges: reduce_range.into_iter().cloned().collect(), reduce_op: *reduce_op },
reduce.dtype(),
))
}
fn fix_store_unroll(store: &Arc<UOp>) -> Option<Arc<UOp>> {
match store.op() {
Op::Store { index, value, ranges } => {
let (store_expand, store_range): (Vec<_>, Vec<_>) =
ranges.iter().partition(|r| matches!(r.op(), Op::Unroll { .. }));
if store_expand.is_empty() {
return None;
}
let contract_axes: Vec<(usize, usize)> = store_expand
.iter()
.filter_map(|u| match u.op() {
Op::Unroll { unroll_axes, .. } => Some(unroll_axes.clone()),
_ => None,
})
.flatten()
.collect();
let new_store = index.store_with_ranges(value.clone(), store_range.into_iter().cloned().collect());
Some(new_store.contract(contract_axes))
}
_ => None,
}
}
fn end_unrolls(uop: &Arc<UOp>) -> Option<Arc<UOp>> {
let Op::End { computation, ranges } = uop.op() else { return None };
let (unrolls, non_unrolls): (Vec<_>, Vec<_>) = ranges.iter().partition(|r| matches!(r.op(), Op::Unroll { .. }));
if unrolls.is_empty() {
return None;
}
let all_axes: Vec<(usize, usize)> = unrolls
.iter()
.filter_map(|u| match u.op() {
Op::Unroll { unroll_axes, .. } => Some(unroll_axes.clone()),
_ => None,
})
.flatten()
.collect();
let contracted = computation.contract(all_axes);
Some(UOp::new(Op::End { computation: contracted, ranges: non_unrolls.into_iter().cloned().collect() }, uop.dtype()))
}
/// Resolves a `Contract` node.
///
/// * Over an `Unroll`: the contracted axes are removed from the unroll. The
///   inner value is re-indexed with the GEP indices from
///   `contract_gep_indices`, and the result is wrapped in an unroll over the
///   axes that were not contracted, typed like the contract.
/// * Over anything else: the source is returned as-is when the contract's
///   dtype has a single lane, otherwise it is repeated once per lane in a
///   `Vectorize`.
///
/// Returns `None` only when `uop` is not a `Contract`.
fn do_contract(uop: &Arc<UOp>) -> Option<Arc<UOp>> {
    let Op::Contract { src: contract_src, upcast_ranges: contract_axes } = uop.op() else {
        return None;
    };

    match contract_src.op() {
        Op::Unroll { src: unroll_inner, unroll_axes } => {
            debug_assert!(
                uop.dtype() == DType::Void
                    || uop.dtype().vcount() == contract_axes.iter().map(|(_, sz)| sz).product::<usize>(),
                "Contract dtype count mismatch"
            );
            // An unroll axis survives when its id is not among the contracted ids.
            let is_contracted = |axis: usize| contract_axes.iter().any(|&(cax, _)| cax == axis);
            let remaining_axes: Vec<(usize, usize)> =
                unroll_axes.iter().filter(|&&(ax, _)| !is_contracted(ax)).cloned().collect();
            let picks = contract_gep_indices(contract_axes, unroll_axes, &remaining_axes);
            Some(unroll_inner.gep(picks).unroll_with_dtype(remaining_axes, uop.dtype()))
        }
        _ => {
            let lanes = uop.dtype().vcount();
            if lanes == 1 {
                return Some(contract_src.clone());
            }
            let copies: SmallVec<[Arc<UOp>; 4]> = std::iter::repeat_n(contract_src.clone(), lanes).collect();
            Some(UOp::vectorize(copies))
        }
    }
}
fn contract_gep_indices(
contract_axes: &[(usize, usize)],
unroll_axes: &[(usize, usize)],
remaining_axes: &[(usize, usize)],
) -> Vec<usize> {
let remaining_choices = choices_from_args(remaining_axes);
let contract_choices = choices_from_args(contract_axes);
let mut indices = Vec::new();
for rpk in &remaining_choices {
for lrpk in &contract_choices {
let mut merged = rpk.clone();
merged.extend(lrpk);
indices.push(expand_arg_to_idx(unroll_axes, &merged));
}
}
indices
}
fn fix_group_for_reduce(reduce: &Arc<UOp>) -> Option<Arc<UOp>> {
let Op::Reduce { src, reduce_op, ranges } = reduce.op() else {
return None;
};
let (reduce_gfr, reduce_r): (Vec<_>, Vec<_>) =
ranges.iter().partition(|r| matches!(r.op(), Op::Range { axis_type: AxisType::GroupReduce, .. }));
if reduce_gfr.is_empty() {
return None;
}
let upstream_locals: Vec<Arc<UOp>> = reduce
.toposort()
.into_iter()
.filter(|u| matches!(u.op(), Op::Range { axis_type: AxisType::Local, .. }))
.collect();
let partial_reduce = if reduce_r.is_empty() {
src.clone()
} else {
UOp::new(
Op::Reduce { src: src.clone(), ranges: reduce_r.into_iter().cloned().collect(), reduce_op: *reduce_op },
reduce.dtype(),
)
};
let reduce_loops: Vec<Arc<UOp>> = reduce_gfr
.iter()
.filter_map(|r| {
let Op::Range { end, axis_id, .. } = r.op() else { return None };
Some(UOp::range_axis(end.clone(), AxisId::Renumbered(axis_id.value() + 100), AxisType::Reduce))
})
.collect();
let buf_ranges: Vec<Arc<UOp>> =
upstream_locals.iter().cloned().chain(reduce_gfr.iter().map(|r| (*r).clone())).collect();
let buf = UOp::bufferize_local(partial_reduce, buf_ranges);
let idx_ranges: Vec<Arc<UOp>> = upstream_locals.iter().cloned().chain(reduce_loops.iter().cloned()).collect();
let indexed = UOp::index().buffer(buf).indices(idx_ranges).call().ok()?;
Some(indexed.reduce(reduce_loops.into_iter().collect(), *reduce_op))
}
/// Pattern set with a single rule: every `Reduce` is offered to
/// `fix_group_for_reduce`, which rewrites it only when it has at least one
/// `GroupReduce` range and declines otherwise.
pub fn pm_group_for_reduce() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
reduce @ Reduce(_, ..) => |reduce| fix_group_for_reduce(reduce),
}
}