use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use super::context::RangeifyContext;
use super::indexing::IndexingContext;
use super::kernel::KernelContext;
use morok_ir::shape::Shape;
use morok_ir::{AddrSpace, AxisType, BufferizeOpts, ConstValue, DType, Op, UOp, UOpKey};
use smallvec::{SmallVec, smallvec};
/// Context for the `add_tags` rewrite pass (see `add_tags_patterns`).
///
/// Every node that receives a tag is appended to `uop_list`; the tag attached
/// to it is its index in that list.
pub struct AddTagsCtx {
    /// Nodes in tagging order; entry `i` is the (untagged) node that was given tag `[i]`.
    pub uop_list: Vec<Arc<UOp>>,
    /// Nodes taken from the AST of any visited `Kernel`; these are never tagged.
    excluded: HashSet<UOpKey>,
}
impl Default for AddTagsCtx {
fn default() -> Self {
Self::new()
}
}
impl AddTagsCtx {
    /// Creates an empty context with no tagged nodes and no exclusions.
    pub fn new() -> Self {
        let uop_list = Vec::default();
        let excluded = HashSet::default();
        Self { uop_list, excluded }
    }
}
/// Returns `true` for ops the `add_tags` pass never tags.
///
/// The skipped ops are `Param`, `Const`, `Device`, `Unique`, `DefineVar`,
/// `Bind`, `End`, `Range`, and every movement op.
fn should_skip_tag(op: &Op) -> bool {
    match op {
        Op::Param { .. }
        | Op::Const(_)
        | Op::Device(_)
        | Op::Unique(_)
        | Op::DefineVar { .. }
        | Op::Bind { .. }
        | Op::End { .. }
        | Op::Range { .. } => true,
        other => other.is_movement(),
    }
}
/// Builds the matcher for the `add_tags` pass.
///
/// Each matched node that is not already tagged and not excluded is appended
/// to `ctx.uop_list` and rewritten with a tag equal to its index in that list.
///
/// These nodes are left untagged:
/// - ops rejected by `should_skip_tag`;
/// - index-typed nodes;
/// - `MStack`/`MSelect` nodes whose sources are all `Param`s.
///
/// Visiting a `Kernel` adds every node of its AST to `ctx.excluded`, so those
/// nodes are ignored if they are matched later.
pub fn add_tags_patterns() -> crate::TypedPatternMatcher<AddTagsCtx> {
    crate::patterns! {
        @context AddTagsCtx;
        x => {
            // Already tagged, or part of a kernel AST seen earlier: leave untouched.
            if x.tag().is_some() || ctx.excluded.contains(&UOpKey(x.clone())) { return None; }
            // Remember the kernel's AST nodes so they are never tagged.
            if let Op::Kernel { ast, .. } = x.op() {
                for u in ast.toposort() {
                    ctx.excluded.insert(UOpKey(u));
                }
            }
            if should_skip_tag(x.op()) { return None; }
            // Index-typed arithmetic is not tagged.
            if x.dtype().base() == morok_dtype::ScalarDType::Index { return None; }
            // Multi-device views built purely from params are not tagged.
            if matches!(x.op(), Op::MStack { .. } | Op::MSelect { .. })
                && x.op().sources().iter().all(|s| matches!(s.op(), Op::Param { .. }))
            {
                return None;
            }
            // Tag value = position of the node in `uop_list`.
            ctx.uop_list.push(x.clone());
            Some(x.with_tag(smallvec![ctx.uop_list.len() - 1]))
        },
    }
}
pub fn rangeify(
sink: Arc<UOp>,
pcontig_config: Option<&super::kernel::PcontigConfig>,
) -> morok_ir::Result<(Arc<UOp>, RangeifyContext)> {
let result = rangeify_with_map(sink, pcontig_config)?;
Ok((result.sink, result.context))
}
/// Output of [`rangeify_with_map`].
pub struct RangeifyResult {
    /// Root of the rewritten graph.
    pub sink: Arc<UOp>,
    /// Range bookkeeping: the range counter from indexing and an (initially empty) range map.
    pub context: RangeifyContext,
    /// Nodes tagged by the `add_tags` pass; entry `i` corresponds to tag `[i]`.
    pub uop_list: Vec<Arc<UOp>>,
}
/// Runs the full rangeify pipeline on `sink`.
///
/// Returns the rewritten graph, its [`RangeifyContext`], and the list of nodes
/// tagged by `add_tags`.
///
/// Stages, in order:
/// 1. `add_tags` (bottom-up rewrite): tag nodes and record them in `uop_list`.
/// 2. Early rewrites combined with contiguous replacement (bottom-up rewrite).
/// 3. Reduce-op splitting.
/// 4. Range assignment and apply-rangeify (`indexing::run_rangeify`).
/// 5. One combined "mega-pass": symbolic simplification, reduce
///    simplification, buffer folding, dead-axis removal, movement-op
///    patterns, and pcontig-aware buffer removal.
/// 6. SINK cleanup: drop sink sources whose base op is not `Bufferize`,
///    `MStack`, `Const`, `Param`, or `After`. The sink is rebuilt only if
///    something was dropped and at least one source remains.
/// 7. Buffer-limit enforcement. This runs only when a device can be extracted
///    from the graph and it reports a `max_buffers` limit.
///
/// `pcontig_config` is cloned for the mega-pass; `None` means the default config.
///
/// # Errors
/// Propagates errors from `indexing::run_rangeify`.
#[allow(clippy::mutable_key_type)]
#[tracing::instrument(skip_all)]
pub fn rangeify_with_map(
    sink: Arc<UOp>,
    pcontig_config: Option<&super::kernel::PcontigConfig>,
) -> morok_ir::Result<RangeifyResult> {
    // Stage: tag nodes so later passes can map results back via `uop_list`.
    let t_stage = std::time::Instant::now();
    let mut tag_ctx = AddTagsCtx::new();
    let mut sink = crate::rewrite::graph_rewrite_bottom_up(&add_tags_patterns(), sink, &mut tag_ctx);
    let uop_list = tag_ctx.uop_list;
    tracing::debug!(
        tagged_count = uop_list.len(),
        node_count = sink.node_count(),
        elapsed_ms = t_stage.elapsed().as_millis() as u64,
        "add_tags complete"
    );
    // Stage: early rewrites and contiguous replacement share one bottom-up pass.
    let t_stage = std::time::Instant::now();
    let early_combined = super::patterns::early_rewrites().with_context::<super::patterns::ReplaceContiguousCtx>()
        + super::patterns::replace_contiguous();
    let mut contig_ctx = super::patterns::ReplaceContiguousCtx::new();
    sink = crate::rewrite::graph_rewrite_bottom_up(&early_combined, sink, &mut contig_ctx);
    tracing::debug!(
        uop.tree = sink.tree(),
        node_count = sink.node_count(),
        elapsed_ms = t_stage.elapsed().as_millis() as u64,
        "early rewrites + replace contiguous complete"
    );
    // Stage: split reduce ops.
    let t_stage = std::time::Instant::now();
    let mut split_config = super::kernel::SplitReduceOpConfig::default();
    let split_matcher = super::patterns::split_reduceop_patterns();
    sink = crate::rewrite::graph_rewrite(&split_matcher, sink, &mut split_config);
    tracing::debug!(
        uop.tree = sink.tree(),
        node_count = sink.node_count(),
        elapsed_ms = t_stage.elapsed().as_millis() as u64,
        "split reduceops complete"
    );
    // Stage 0: assign ranges and apply rangeify; `indexing_ctx` supplies the range counter below.
    let t_stage = std::time::Instant::now();
    let (rangeified, indexing_ctx) = super::indexing::run_rangeify(sink)?;
    sink = rangeified;
    tracing::debug!(
        uop.tree = sink.tree(),
        node_count = sink.node_count(),
        elapsed_ms = t_stage.elapsed().as_millis() as u64,
        "Stage 0: range assignment + apply rangeify complete"
    );
    {
        use super::kernel::PcontigConfig;
        let t_stage = std::time::Instant::now();
        use std::sync::LazyLock;
        // The combined matcher is built once (process-wide static) and reused on every call.
        static MEGA_PASS: LazyLock<crate::TypedPatternMatcher<PcontigConfig>> = LazyLock::new(|| {
            crate::symbolic::symbolic().with_context::<PcontigConfig>()
                + super::patterns::pm_reduce_simplify().with_context()
                + super::patterns::buffer_folding().with_context()
                + super::patterns::dead_axis_removal().with_context()
                + super::patterns::movement_op_patterns().with_context()
                + super::patterns::buffer_removal_with_pcontig()
        });
        let mega_pass = &*MEGA_PASS;
        tracing::debug!(
            total_patterns = mega_pass.len(),
            wildcard_count = mega_pass.wildcard_count(),
            indexed_buckets = mega_pass.indexed_count(),
            "mega-pass pattern stats"
        );
        let mut pcontig = pcontig_config.cloned().unwrap_or_default();
        sink = crate::rewrite::graph_rewrite(mega_pass, sink, &mut pcontig);
        tracing::debug!(
            node_count = sink.node_count(),
            elapsed_ms = t_stage.elapsed().as_millis() as u64,
            "mega-pass complete"
        );
    }
    // SINK cleanup: keep only sources whose base op is a valid sink output kind.
    if let Op::Sink { sources } = sink.op() {
        let filtered: Vec<Arc<UOp>> = sources
            .iter()
            .filter(|s| {
                let valid_op = matches!(
                    s.base().op(),
                    Op::Bufferize { .. } | Op::MStack { .. } | Op::Const(_) | Op::Param { .. } | Op::After { .. }
                );
                valid_op
            })
            .cloned()
            .collect();
        // Rebuild only if something was removed and at least one source survives.
        if !filtered.is_empty() && filtered.len() != sources.len() {
            tracing::debug!(
                original = sources.len(),
                filtered = filtered.len(),
                "SINK cleanup: removed invalid-type sources after mega-pass"
            );
            sink = UOp::sink(filtered);
        }
    }
    // Stage 7b: only when the graph's device reports a per-kernel buffer limit.
    if let Some(device) = super::patterns::extract_device_from_graph(&sink)
        && let Some(limit) = device.max_buffers()
    {
        let t_stage = std::time::Instant::now();
        let limit_matcher = super::patterns::buffer_limit_patterns(limit);
        sink = crate::rewrite::graph_rewrite(&limit_matcher, sink, &mut ());
        tracing::debug!(
            uop.tree = sink.tree(),
            elapsed_ms = t_stage.elapsed().as_millis() as u64,
            "Stage 7b: buffer limit enforcement complete"
        );
    }
    // The range map starts empty; only the range counter is carried over from indexing.
    let rangeify_ctx = RangeifyContext { range_counter: indexing_ctx.range_counter(), range_map: HashMap::new() };
    Ok(RangeifyResult { sink, context: rangeify_ctx, uop_list })
}
/// Cached matcher that applies [`flatten_range_impl`] to `End`, `Reduce`, and
/// `Store` nodes that carry a non-empty range list.
pub fn pm_flatten_range() -> &'static crate::TypedPatternMatcher {
    crate::cached_patterns! {
        r @ End { computation: _, ranges } if !ranges.is_empty() => |r| flatten_range_impl(r),
        r @ Reduce { src: _, ranges, reduce_op: _ } if !ranges.is_empty() => |r| flatten_range_impl(r),
        r @ Store { index: _, value: _, ranges } if !ranges.is_empty() => |r| flatten_range_impl(r),
    }
}
/// State for the [`pm_split_ranges`] pass.
#[derive(Default)]
pub struct SplitRangesContext {
    /// Range node id -> modulus to split it by (the first recorded modulus wins).
    pub marked_ranges: HashMap<u64, i64>,
    /// Range node ids reachable from an image-typed store index; these are never split.
    protected_ranges: HashSet<u64>,
}
/// Builds the matcher for the range-splitting pass.
///
/// Patterns:
/// - `Mod(Range, Const)`, where the range end is a constant divisible by the
///   positive modulus: record the range in `ctx.marked_ranges`. The node is
///   not rewritten.
/// - `Store` whose `Index` buffer has an image dtype: protect every range
///   reachable from the index and drop any mark already recorded for it.
/// - `Sink`, when any range is marked: substitute each marked range via
///   `do_split_ranges_substitute`.
pub fn pm_split_ranges() -> crate::TypedPatternMatcher<SplitRangesContext> {
    crate::patterns! {
        @context SplitRangesContext;
        // Record ranges that are taken modulo an exact divisor of their size.
        _modop @ Mod(r @ Range { end, axis_id: _, axis_type: _ }, c @ Const(_))
            if is_divisible_range_end(end, c) => |r, c| {
                mark_range_mod(ctx, r, c);
                None
            },
        // Image stores must keep their original index ranges.
        _store @ Store { index: idx @ Index { buffer: buf, indices: _, gate: _ }, value: _, ranges: _ }
            if is_image_dtype(buf) => |idx| {
                protect_ranges_for_image(ctx, idx);
                None
            },
        // Perform all recorded splits once, at the graph root.
        sink @ Sink { sources: _ } if !ctx.marked_ranges.is_empty() => |sink| {
            do_split_ranges_substitute(ctx, sink)
        },
    }
}
/// Returns `true` when `buf` has an image dtype.
fn is_image_dtype(buf: &Arc<UOp>) -> bool {
    let dtype = buf.dtype();
    matches!(dtype, DType::Image { .. })
}
/// Protects every `Range` reachable from `idx` against splitting.
///
/// Each such range is added to `protected_ranges`, and any split already
/// recorded for it in `marked_ranges` is dropped.
fn protect_ranges_for_image(ctx: &mut SplitRangesContext, idx: &Arc<UOp>) {
    idx.toposort()
        .into_iter()
        .filter(|node| matches!(node.op(), Op::Range { .. }))
        .for_each(|range| {
            ctx.protected_ranges.insert(range.id);
            ctx.marked_ranges.remove(&range.id);
        });
}
/// Extracts an integer value from a `Const` node.
///
/// Returns `None` in these cases:
/// - the node is not a `Const`;
/// - the constant is not an integer;
/// - the constant is unsigned and does not fit in `i64`.
fn const_uop_to_i64(c: &Arc<UOp>) -> Option<i64> {
    let Op::Const(cv) = c.op() else {
        return None;
    };
    match cv.0 {
        ConstValue::Int(v) => Some(v),
        // `v as i64` would silently wrap values above `i64::MAX` into negative
        // numbers, which callers would then use as range sizes / moduli.
        ConstValue::UInt(v) => i64::try_from(v).ok(),
        _ => None,
    }
}
/// Returns `true` when `end` and `c` are both integer constants, `c` is
/// positive, and `c` divides `end` exactly.
fn is_divisible_range_end(end: &Arc<UOp>, c: &Arc<UOp>) -> bool {
    match (const_uop_to_i64(end), const_uop_to_i64(c)) {
        (Some(end_val), Some(modulus)) => modulus > 0 && end_val % modulus == 0,
        _ => false,
    }
}
fn mark_range_mod(ctx: &mut SplitRangesContext, r: &Arc<UOp>, c: &Arc<UOp>) {
if ctx.marked_ranges.contains_key(&r.id) || ctx.protected_ranges.contains(&r.id) {
return;
}
if let Some(mod_val) = const_uop_to_i64(c) {
ctx.marked_ranges.insert(r.id, mod_val);
}
}
/// Splits every marked, unprotected range in the graph rooted at `sink`.
///
/// A range with constant end `E` and modulus `m` is replaced by
/// `outer * m + inner`, where `outer` ranges over `E / m` and `inner` over `m`.
/// Both new ranges keep the original axis type and get fresh renumbered axis
/// ids above the current maximum. After substitution `ctx.marked_ranges` is
/// cleared.
///
/// Returns `None` if nothing is marked or no marked range could be split.
fn do_split_ranges_substitute(ctx: &mut SplitRangesContext, sink: &Arc<UOp>) -> Option<Arc<UOp>> {
    use morok_ir::AxisId;
    use morok_ir::rewrite::graph_rewrite_bottom_up;

    if ctx.marked_ranges.is_empty() {
        return None;
    }
    let topo = sink.toposort();
    // Fresh axis ids start just above the largest one already in the graph.
    let mut next_id = topo
        .iter()
        .filter_map(|node| match node.op() {
            Op::Range { axis_id, .. } => Some(axis_id.value()),
            _ => None,
        })
        .max()
        .unwrap_or(0)
        + 1;

    let mut subs: HashMap<u64, Arc<UOp>> = HashMap::new();
    for node in topo.iter().filter(|n| !ctx.protected_ranges.contains(&n.id)) {
        let Some(&inner_size) = ctx.marked_ranges.get(&node.id) else { continue };
        let Op::Range { end, axis_type, .. } = node.op() else { continue };
        let Some(total) = const_uop_to_i64(end) else { continue };
        let outer = UOp::range_axis(UOp::index_const(total / inner_size), AxisId::Renumbered(next_id), *axis_type);
        let inner = UOp::range_axis(UOp::index_const(inner_size), AxisId::Renumbered(next_id + 1), *axis_type);
        next_id += 2;
        let combined = outer.mul(&UOp::index_const(inner_size)).add(&inner);
        subs.insert(node.id, combined);
    }
    if subs.is_empty() {
        return None;
    }

    let substitute_pm = crate::patterns! {
        r @ Range { end: _, axis_id: _, axis_type: _ } if subs.contains_key(&r.id) => {
            subs.get(&r.id).cloned()
        },
    };
    let rewritten = graph_rewrite_bottom_up(&substitute_pm, sink.clone(), &mut ());
    ctx.marked_ranges.clear();
    Some(rewritten)
}
/// Rewrites the sources of consumer `x` through `transform_single_source`.
///
/// The consumer's input ranges (from `ctx`, or none) are passed to every
/// source. Returns the new source list only if at least one source changed
/// (by pointer identity).
///
/// `Bufferize`, `Index`, and `After` consumers, and nodes without sources,
/// return `None`.
pub fn transform_sources_with_bufferize(x: &Arc<UOp>, ctx: &mut IndexingContext) -> Option<Vec<Arc<UOp>>> {
    if matches!(x.op(), Op::Bufferize { .. } | Op::Index { .. } | Op::After { .. }) {
        return None;
    }
    let sources = x.op().sources();
    if sources.is_empty() {
        return None;
    }
    let input_ranges: Vec<Arc<UOp>> = match ctx.get_ranges(x) {
        Some((ranges, _)) => ranges.clone(),
        None => Vec::new(),
    };
    let mut changed = false;
    let rewritten: Vec<Arc<UOp>> = sources
        .iter()
        .map(|src| {
            let new_src = transform_single_source(x, src, &input_ranges, ctx);
            changed |= !Arc::ptr_eq(&new_src, src);
            new_src
        })
        .collect();
    changed.then_some(rewritten)
}
fn flatten_bufferize(bufferize: &Arc<UOp>) -> Option<Arc<UOp>> {
let Op::Bufferize { compute, ranges, opts } = bufferize.op() else { return None };
if ranges.len() <= 1 {
return None;
}
let shape: Vec<morok_ir::SInt> = ranges
.iter()
.map(|r| match r.op() {
Op::Range { end, .. } => morok_ir::SInt::from(end.clone()),
_ => morok_ir::SInt::from(1usize),
})
.collect();
let flat_shape = vec![morok_ir::sint_prod(&shape)];
let ranges_vec: Vec<Arc<UOp>> = ranges.iter().cloned().collect();
let flat_indices = super::indexing::apply_reshape_ranges(&flat_shape, &shape, &ranges_vec);
assert_eq!(flat_indices.len(), 1, "flatten_bufferize: expected 1 flat index, got {}", flat_indices.len());
let flat_buf = UOp::bufferize(compute.clone(), vec![flat_indices[0].clone()], opts.clone());
let shape_smallvec: Shape = shape.iter().cloned().collect();
let reshaped = flat_buf.try_reshape(&shape_smallvec).expect("flatten_bufferize: try_reshape failed");
let has_symbolic =
ranges.iter().any(|r| matches!(r.op(), Op::Range { end, .. } if !matches!(end.op(), Op::Const(_))));
if has_symbolic {
let sym_ranges: Vec<(morok_ir::SInt, morok_ir::SInt)> = ranges
.iter()
.map(|r| match r.op() {
Op::Range { end, .. } => (morok_ir::SInt::from(0usize), morok_ir::SInt::from(end.clone())),
_ => (morok_ir::SInt::from(0usize), morok_ir::SInt::from(1usize)),
})
.collect();
Some(reshaped.try_shrink(&sym_ranges).expect("flatten_bufferize: try_shrink failed for symbolic ranges"))
} else {
Some(reshaped)
}
}
/// Moves an `After` dependency below a movement op.
///
/// Rewrites `mop(src)` into `mop(src.after(deps))`, keeping the movement op's
/// arguments and dtype. Returns `None` if `mop` is not one of Reshape,
/// Permute, Expand, Pad, Shrink, or Flip.
pub(crate) fn push_movement_through_after(mop: &Arc<UOp>, deps: &SmallVec<[Arc<UOp>; 4]>) -> Option<Arc<UOp>> {
    // Attach the dependencies to the movement op's input first...
    let inner = mop.op().sources()[0].clone();
    let src = inner.after(deps.clone());
    // ...then rebuild the same movement op on top of that AFTER.
    let rebuilt = match mop.op() {
        Op::Reshape { new_shape, .. } => Op::Reshape { src, new_shape: new_shape.clone() },
        Op::Permute { axes, .. } => Op::Permute { src, axes: axes.clone() },
        Op::Expand { new_shape, .. } => Op::Expand { src, new_shape: new_shape.clone() },
        Op::Pad { begin_pads, end_pads, .. } => {
            Op::Pad { src, begin_pads: begin_pads.clone(), end_pads: end_pads.clone() }
        }
        Op::Shrink { begins, ends, .. } => Op::Shrink { src, begins: begins.clone(), ends: ends.clone() },
        Op::Flip { axes, .. } => Op::Flip { src, axes: axes.clone() },
        _ => return None,
    };
    Some(UOp::new(rebuilt, mop.dtype()))
}
pub(crate) fn transform_single_source(
consumer: &Arc<UOp>,
src: &Arc<UOp>,
input_ranges: &[Arc<UOp>],
ctx: &mut IndexingContext,
) -> Arc<UOp> {
if matches!(
src.op(),
Op::Buffer { .. }
| Op::Param { .. }
| Op::BufferView { .. }
| Op::MStack { .. }
| Op::MSelect { .. }
| Op::After { .. }
) {
if !input_ranges.is_empty() {
return UOp::index()
.buffer(Arc::clone(src))
.indices(input_ranges.to_vec())
.call()
.expect("Failed to create INDEX for buffer source");
}
return Arc::clone(src);
}
let realize_axes_opt = ctx.get_realize_axes(src).cloned();
if let Some(ref realize_axes) = realize_axes_opt {
let (_, output_ranges) = ctx.get_ranges(src).expect("Realized op must have ranges");
let closed_ranges: Vec<_> = output_ranges
.iter()
.enumerate()
.filter(|(i, _)| realize_axes.contains(i))
.map(|(_, r)| Arc::clone(r))
.collect();
let is_copy_consumer = matches!(consumer.op(), Op::Copy { .. });
let is_always_contiguous_src = super::indexing::is_always_contiguous(src);
let removable = !is_copy_consumer && !is_always_contiguous_src;
let addrspace = if output_ranges.len() == realize_axes.len() { AddrSpace::Global } else { AddrSpace::Local };
tracing::debug!(
src_id = src.id,
src_op = src.op().as_ref(),
consumer_id = consumer.id,
consumer_op = consumer.op().as_ref(),
realize_axes = ?realize_axes,
output_ranges_len = output_ranges.len(),
addrspace = ?addrspace,
removable = removable,
"BUFFERIZE decision"
);
let device = src.device_spec();
let opts = BufferizeOpts { device, addrspace, removable };
let buf_tag = if addrspace == AddrSpace::Global { src.tag().clone() } else { None };
let bufferized = UOp::bufferize(Arc::clone(src), closed_ranges.clone(), opts);
let bufferized = if let Some(t) = buf_tag { bufferized.with_tag(t) } else { bufferized };
let index_ranges: Vec<_> = input_ranges
.iter()
.enumerate()
.filter(|(i, _)| realize_axes.contains(i))
.map(|(_, r)| Arc::clone(r))
.collect();
if !index_ranges.is_empty() {
return UOp::index()
.buffer(bufferized)
.indices(index_ranges)
.call()
.expect("Failed to create INDEX after BUFFERIZE");
} else {
return bufferized;
}
}
Arc::clone(src)
}
/// Replays on `result` the chain of movement ops that ends at `chain`.
///
/// Walks from `chain` down through consecutive movement ops (Reshape, Permute,
/// Expand, Pad, Shrink, Flip). It then re-applies them to `result` starting
/// with the innermost one, so the outermost op of `chain` is also outermost in
/// the output.
///
/// Returns `None` if any op cannot be re-applied (see `apply_single_movement_op`).
fn apply_movement_ops_chain(result: &Arc<UOp>, chain: &Arc<UOp>) -> Option<Arc<UOp>> {
    // Collect the chain outermost-first.
    let mut mops = Vec::new();
    let mut walk = chain.clone();
    while walk.op().is_movement() {
        mops.push(walk.clone());
        walk = match walk.op() {
            Op::Reshape { src, .. }
            | Op::Permute { src, .. }
            | Op::Expand { src, .. }
            | Op::Pad { src, .. }
            | Op::Shrink { src, .. }
            | Op::Flip { src, .. } => src.clone(),
            // A movement op without one of the fields above: stop walking.
            _ => break,
        };
    }
    // Re-apply innermost-first; any failure aborts the whole replay.
    mops.iter()
        .rev()
        .try_fold(Arc::clone(result), |current, mop| apply_single_movement_op(&current, mop.op()))
}
/// Applies the movement op described by `op` (taking its shape/axis arguments)
/// to `uop`.
///
/// Returns `None` if `op` is not a movement op, if a shape argument cannot be
/// decoded by `extract_shape_from_uop`, or if the movement itself fails.
fn apply_single_movement_op(uop: &Arc<UOp>, op: &Op) -> Option<Arc<UOp>> {
    match op {
        Op::Reshape { new_shape, .. } => uop.try_reshape(&extract_shape_from_uop(new_shape)?).ok(),
        Op::Expand { new_shape, .. } => uop.try_expand(&extract_shape_from_uop(new_shape)?).ok(),
        Op::Permute { axes, .. } => uop.try_permute(axes.clone()).ok(),
        Op::Flip { axes, .. } => uop.try_flip(axes.clone()).ok(),
        Op::Pad { begin_pads, end_pads, .. } => {
            let padding: Vec<_> =
                extract_shape_from_uop(begin_pads)?.into_iter().zip(extract_shape_from_uop(end_pads)?).collect();
            uop.try_pad(&padding).ok()
        }
        Op::Shrink { begins, ends, .. } => {
            let bounds: Vec<_> =
                extract_shape_from_uop(begins)?.into_iter().zip(extract_shape_from_uop(ends)?).collect();
            uop.try_shrink(&bounds).ok()
        }
        _ => None,
    }
}
fn extract_shape_from_uop(shape_uop: &Arc<UOp>) -> Option<Shape> {
use morok_ir::SInt;
match shape_uop.op() {
Op::Vectorize { elements } => Some(elements.iter().cloned().map(SInt::from).collect()),
Op::Const(const_hash) => match const_hash.0 {
ConstValue::Int(v) if v >= 0 => Some(smallvec![SInt::from(v as usize)]),
ConstValue::UInt(v) => Some(smallvec![SInt::from(v as usize)]),
_ => None,
},
Op::VConst { values } => {
let mut dims = smallvec![];
for val in values {
match val {
ConstValue::Int(v) if *v >= 0 => dims.push(SInt::from(*v as usize)),
ConstValue::UInt(v) => dims.push(SInt::from(*v as usize)),
_ => return None,
}
}
Some(dims)
}
_ => None,
}
}
/// Builds a `Loop` range of length `size` that reuses `outer`'s axis id.
///
/// Returns `None` if `outer` is not a `Range`.
fn create_loop_range_from_outer(outer: &Arc<UOp>, size: usize) -> Option<Arc<UOp>> {
    use morok_ir::AxisType;
    match outer.op() {
        Op::Range { axis_id, .. } => {
            let end = UOp::index_const(size as i64);
            Some(UOp::range_axis(end, *axis_id, AxisType::Loop))
        }
        _ => None,
    }
}
/// Expresses one step of reduction `op` as a binary combination of `lhs` and `rhs`.
///
/// `Add`, `Mul`, and `Max` map to the matching binary op with `lhs`'s dtype.
/// `Min` is written as `where(lhs < rhs, lhs, rhs)`. Always returns `Some`.
///
/// # Panics
/// If building the `where` for `Min` fails.
fn reduce_op_to_binary(op: morok_ir::ReduceOp, lhs: &Arc<UOp>, rhs: &Arc<UOp>) -> Option<Arc<UOp>> {
    use morok_ir::types::{BinaryOp, ReduceOp};
    let dtype = lhs.dtype();
    let kind = match op {
        ReduceOp::Add => BinaryOp::Add,
        ReduceOp::Mul => BinaryOp::Mul,
        ReduceOp::Max => BinaryOp::Max,
        ReduceOp::Min => {
            let lt = UOp::new(Op::Binary(BinaryOp::Lt, lhs.clone(), rhs.clone()), morok_dtype::DType::Bool);
            let picked =
                UOp::try_where(lt, lhs.clone(), rhs.clone()).expect("reduce_op_to_binary: try_where failed for Min");
            return Some(picked);
        }
    };
    Some(UOp::new(Op::Binary(kind, lhs.clone(), rhs.clone()), dtype))
}
/// Returns the number of elements a buffer indexed by `ranges` needs.
///
/// This is the product over all ranges of `vmax + 1`, or 1 for an empty list.
///
/// # Panics
/// If any range's `vmax` is not a non-negative integer: buffers need
/// concrete sizes.
fn calculate_size_from_ranges(ranges: &SmallVec<[Arc<UOp>; 4]>) -> usize {
    let mut total: usize = 1;
    for range in ranges {
        let extent = match range.vmax() {
            ConstValue::Int(v) if *v >= 0 => (*v + 1) as usize,
            ConstValue::UInt(v) => (*v + 1) as usize,
            other => panic!(
                "Cannot allocate buffer: range vmax resolved to {:?}. \
                 Buffers require concrete sizes (Tinygrad: 'no symbolic sized buffers')",
                other
            ),
        };
        total *= extent;
    }
    total
}
/// Returns `ranges` sorted by `(axis id, axis-type rank)`.
///
/// The sort is stable. Entries that are not `Range` nodes sort last.
fn sort_ranges_by_axis_id(ranges: &SmallVec<[Arc<UOp>; 4]>) -> SmallVec<[Arc<UOp>; 4]> {
    let key_of = |r: &Arc<UOp>| match r.op() {
        Op::Range { axis_id, axis_type, .. } => (axis_id.value(), axis_type_ordinal(*axis_type)),
        _ => (usize::MAX, u8::MAX),
    };
    let mut ordered = ranges.clone();
    ordered.sort_by_key(key_of);
    ordered
}
/// Returns the fixed rank of an [`AxisType`].
///
/// Used as the secondary sort key when ranges share an axis id.
fn axis_type_ordinal(at: AxisType) -> u8 {
    match at {
        AxisType::Placeholder => 10,
        AxisType::Thread => 9,
        AxisType::Unroll => 8,
        AxisType::Upcast => 7,
        AxisType::Reduce => 6,
        AxisType::GroupReduce => 5,
        AxisType::Loop => 4,
        AxisType::Local => 3,
        AxisType::Warp => 2,
        AxisType::Global => 1,
        AxisType::Outer => 0,
    }
}
/// Collects the `Range` nodes a list of bufferize ranges depends on.
///
/// - `Range` entries are taken directly.
/// - `Const` entries are ignored.
/// - Any other expression contributes the ranges reported by its `ranges()`.
///
/// Each range (compared by node id) appears at most once, in first-seen order.
fn collect_range_uops(ranges: &SmallVec<[Arc<UOp>; 4]>) -> SmallVec<[Arc<UOp>; 4]> {
    /// Appends `rng` unless a node with the same id is already present.
    fn push_unique(out: &mut SmallVec<[Arc<UOp>; 4]>, rng: &Arc<UOp>) {
        if !out.iter().any(|c| c.id == rng.id) {
            out.push(Arc::clone(rng));
        }
    }

    let mut collected = SmallVec::new();
    for r in ranges {
        match r.op() {
            // Direct ranges are deduplicated too. Otherwise a range listed twice,
            // or also reached through an earlier expression, would be closed
            // twice by the resulting END.
            Op::Range { .. } => push_unique(&mut collected, r),
            Op::Const(_) => {}
            _ => {
                for rng in r.ranges().iter() {
                    push_unique(&mut collected, rng);
                }
            }
        }
    }
    collected
}
/// Lowers a `Bufferize` node into an explicit store followed by an `After`.
///
/// Three lowering paths are tried in order:
/// 1. **`Assign` compute.** Store `value` through the assign target's `Index`
///    (re-typed to the bufferize pointer dtype) and close the bufferize ranges
///    with `End`. The assign's optional movement-op chain is then replayed on
///    the result.
/// 2. **`Reduce` over a single `Outer` range** (global address space only).
///    - Allocate a new buffer and initialise it with the reduce identity.
///    - Accumulate `buf[idx] = buf[idx] <op> src`.
///    - Close the bufferize ranges, then the outer range.
/// 3. **Everything else.** Store `compute` into one of:
///    - an already-mapped buffer;
///    - a new global buffer;
///    - a new `DefineLocal`.
///
///    The ranges are closed with `End`, and a barrier is added for local
///    address space.
///
/// Every successful path registers its result with
/// `ctx.map_buffer(bufferize_op, result)`.
///
/// Returns `None` when:
/// - `bufferize_op` is not a `Bufferize`;
/// - an `Assign` target is not an `Index`;
/// - an `Outer` reduce is not in global address space;
/// - a local buffer is requested while `allow_locals` is false;
/// - a helper fails (movement-op replay, loop-range creation, reduce lowering).
///
/// # Panics
/// - If a range's `vmax` is not a non-negative integer (via
///   `calculate_size_from_ranges`).
/// - If the general path sees more than one non-constant bufferize range.
/// - If `Index` construction fails.
pub fn bufferize_to_store(bufferize_op: &Arc<UOp>, ctx: &mut KernelContext, allow_locals: bool) -> Option<Arc<UOp>> {
    let Op::Bufferize { compute, ranges, opts } = bufferize_op.op() else {
        return None;
    };
    tracing::debug!(
        bufferize_id = bufferize_op.id,
        compute_id = compute.id,
        ranges_len = ranges.len(),
        allow_locals = allow_locals,
        "bufferize_to_store: CONVERTING BUFFERIZE to STORE→AFTER"
    );

    let size = calculate_size_from_ranges(ranges);
    let base_dtype = match bufferize_op.dtype() {
        DType::Ptr { base, .. } => (*base).clone(),
        other => other,
    };
    let sdtype = base_dtype.clone().ptr(Some(size), opts.addrspace);
    // RANGE nodes to close with END, ordered by axis id; computed once and
    // shared by all three paths.
    let end_ranges: SmallVec<[Arc<UOp>; 4]> = sort_ranges_by_axis_id(&collect_range_uops(ranges));

    // Path 1: ASSIGN — write `value` through the existing target index.
    if let Op::Assign { target, value, movement_ops } = compute.op() {
        let Op::Index { buffer, indices, gate } = target.op() else {
            return None;
        };
        let store_target = UOp::index()
            .buffer(buffer.clone())
            .indices(indices.to_vec())
            .maybe_gate(gate.clone())
            .dtype(sdtype.clone())
            .call()
            .expect("bufferize_to_store: failed to create INDEX for ASSIGN target");
        let store = store_target.store_value(value.clone());
        let do_store = if end_ranges.is_empty() { store } else { store.end(end_ranges.clone()) };
        let mut result = buffer.after(smallvec![do_store]);
        if let Some(mops_chain) = movement_ops {
            result = apply_movement_ops_chain(&result, mops_chain)?;
        }
        ctx.map_buffer(bufferize_op.clone(), result.clone());
        return Some(result);
    }

    // Path 2: REDUCE over a single OUTER range — accumulate into a fresh global buffer.
    if let Op::Reduce { src: reduce_src, ranges: reduce_ranges, reduce_op } = compute.op()
        && reduce_ranges.len() == 1
        && let Op::Range { axis_type, .. } = reduce_ranges[0].op()
        && *axis_type == AxisType::Outer
    {
        use crate::symbolic::dce::reduce_identity;

        if opts.addrspace != AddrSpace::Global {
            return None;
        }
        let outer_range = reduce_ranges[0].clone();
        let device = opts.device.clone().unwrap_or(morok_ir::DeviceSpec::Cpu);
        let buf = UOp::new_buffer(device, size, base_dtype.clone());
        // Initialise every slot with the reduction's identity element.
        let zero_range = create_loop_range_from_outer(&outer_range, size)?;
        let identity = reduce_identity(*reduce_op, base_dtype.clone());
        let zero_idx = UOp::index()
            .buffer(buf.clone())
            .indices(vec![zero_range.clone()])
            .dtype(sdtype.clone())
            .call()
            .expect("bufferize_to_store: failed to create INDEX for OUTER REDUCE zero-init");
        let zero_store = zero_idx.store_value(identity).end(smallvec![zero_range.clone()]);
        let buf_zeroed = buf.after(smallvec![zero_store]);
        debug_assert!(
            ranges.len() <= 1 || ranges.iter().all(|r| matches!(r.op(), Op::Const(_))),
            "bufferize_to_store: unexpected multi-range in OUTER REDUCE after flatten_bufferize"
        );
        let idx = if ranges.len() == 1 && !matches!(ranges[0].op(), Op::Const(_)) {
            ranges[0].clone()
        } else if let Some(first) = end_ranges.first() {
            first.clone()
        } else {
            UOp::index_const(0)
        };
        let buf_idx = UOp::index()
            .buffer(buf_zeroed.clone())
            .indices(vec![idx])
            .dtype(sdtype.clone())
            .call()
            .expect("bufferize_to_store: failed to create INDEX for OUTER REDUCE accumulation");
        let loaded = UOp::load().buffer(buf_zeroed.clone()).index(buf_idx.clone()).call();
        let accumulated = reduce_op_to_binary(*reduce_op, &loaded, reduce_src)?;
        let stored = buf_idx.store_value(accumulated);
        // Close the bufferize ranges (if any), then the OUTER range. Skipping an
        // empty END matches the guards used by the other two paths.
        let stored = if end_ranges.is_empty() { stored } else { stored.end(end_ranges.clone()) };
        let do_store = stored.end(smallvec![outer_range]);
        let result = buf_zeroed.after(smallvec![do_store]);
        ctx.map_buffer(bufferize_op.clone(), result.clone());
        return Some(result);
    }

    // Path 3: generic store into a global or local buffer.
    if !allow_locals && opts.addrspace == AddrSpace::Local {
        return None;
    }
    let buffer = if let Some(existing_buffer) = ctx.get_buffer(bufferize_op) {
        existing_buffer.clone()
    } else if opts.addrspace == AddrSpace::Global {
        let device = opts.device.clone().unwrap_or(morok_ir::DeviceSpec::Cpu);
        UOp::new_buffer(device, size, base_dtype.clone())
    } else {
        // `sdtype` is already `base_dtype` as a pointer of `size` in `opts.addrspace`.
        let local_id = ctx.next_local();
        UOp::define_local(local_id, sdtype.clone())
    };
    let vcount = compute.dtype().vcount();
    let store_buffer = if vcount > 1 { buffer.broadcast(vcount) } else { buffer.clone() };
    let store_target = if let Some(first_range) = end_ranges.first() {
        assert!(
            ranges.len() <= 1 || ranges.iter().all(|r| matches!(r.op(), Op::Const(_))),
            "bufferize_to_store: unexpected multi-range in general path after flatten_bufferize"
        );
        let idx = if ranges.len() == 1 && !matches!(ranges[0].op(), Op::Const(_)) {
            ranges[0].clone()
        } else {
            first_range.clone()
        };
        UOp::index()
            .buffer(store_buffer)
            .indices(vec![idx])
            .dtype(sdtype.clone())
            .call()
            .expect("Failed to create INDEX for BUFFERIZE-to-STORE conversion")
    } else {
        UOp::index()
            .buffer(store_buffer)
            .indices(vec![UOp::index_const(0)])
            .dtype(sdtype.clone())
            .call()
            .expect("Failed to create INDEX for scalar STORE")
    };
    let store = store_target.store_value(compute.clone());
    let mut do_store = if end_ranges.is_empty() { store } else { store.end(end_ranges) };
    if opts.addrspace == AddrSpace::Local {
        do_store = do_store.barrier(SmallVec::new());
    }
    let result = buffer.after(smallvec![do_store]);
    ctx.map_buffer(bufferize_op.clone(), result.clone());
    Some(result)
}
/// Splits `ranges` into those present in `src_ranges` ("parented") and those
/// that are not ("unparented"). Input order is preserved within each group.
#[allow(clippy::mutable_key_type)]
pub(crate) fn partition_reduce_ranges(
    ranges: &SmallVec<[Arc<UOp>; 4]>,
    src_ranges: &HashSet<UOpKey>,
) -> (SmallVec<[Arc<UOp>; 4]>, Vec<Arc<UOp>>) {
    let mut parented: SmallVec<[Arc<UOp>; 4]> = SmallVec::new();
    let mut unparented: Vec<Arc<UOp>> = Vec::new();
    ranges.iter().for_each(|range| {
        if src_ranges.contains(&UOpKey(Arc::clone(range))) {
            parented.push(Arc::clone(range));
        } else {
            unparented.push(Arc::clone(range));
        }
    });
    (parented, unparented)
}
/// Returns the `end` (size) operand of a `Range` node, or `None` for any other op.
pub(crate) fn get_range_size(range: &Arc<UOp>) -> Option<Arc<UOp>> {
    match range.op() {
        Op::Range { end, .. } => Some(Arc::clone(end)),
        _ => None,
    }
}
#[allow(clippy::mutable_key_type)]
fn reduce_collapse_with(src: &Arc<UOp>, ranges: &[Arc<UOp>], pm: &crate::TypedPatternMatcher<()>) -> Option<Arc<UOp>> {
use morok_ir::ReduceOp;
if ranges.is_empty() {
return None;
}
let mut u = Arc::clone(src);
for range in ranges {
let range_key = UOpKey(range.clone());
let in_scope: HashSet<UOpKey> =
u.toposort_filtered(|node| node.in_scope_ranges().contains(&range_key)).into_iter().map(UOpKey).collect();
if in_scope.iter().any(|k| matches!(k.0.op(), Op::Reduce { .. } | Op::Store { .. })) {
return None;
}
let mut replaces: HashMap<UOpKey, Arc<UOp>> = HashMap::new();
for node in &in_scope {
node.0.op().map_child(|child| {
let key = UOpKey(child.clone());
if in_scope.contains(&key) || replaces.contains_key(&key) {
return;
}
if matches!(
child.op(),
Op::Const(_)
| Op::VConst { .. }
| Op::DefineVar { .. }
| Op::Param { device: None, .. }
| Op::DefineLocal { .. }
) {
return;
}
let vmin = match child.vmin() {
ConstValue::Int(i) => *i,
ConstValue::UInt(u) => *u as i64,
ConstValue::Float(f) => *f as i64,
ConstValue::Bool(b) => *b as i64,
};
let vmax = match child.vmax() {
ConstValue::Int(i) => *i,
ConstValue::UInt(u) => *u as i64,
ConstValue::Float(f) => *f as i64,
ConstValue::Bool(b) => *b as i64,
};
let var = UOp::define_var(format!("in{}", replaces.len()), vmin, vmax).with_dtype(child.dtype());
replaces.insert(key, var);
});
}
let substituted = u.substitute(&replaces);
let synthetic_reduce = substituted.reduce(smallvec![range.clone()], ReduceOp::Add);
let result = crate::rewrite::graph_rewrite(pm, synthetic_reduce, &mut ());
let has_range = result.toposort().iter().any(|x| matches!(x.op(), Op::Range { .. }));
if has_range {
return None;
}
let reverse: HashMap<UOpKey, Arc<UOp>> = replaces.into_iter().map(|(k, v)| (UOpKey(v), k.0)).collect();
u = result.substitute(&reverse);
}
Some(u)
}
/// Tries to eliminate `ranges` from `src` by collapsing an ADD-reduce over
/// them with the standard reduce-collapse matcher.
///
/// Returns `None` if any range cannot be eliminated.
pub fn reduce_collapse(src: &Arc<UOp>, ranges: &[Arc<UOp>]) -> Option<Arc<UOp>> {
    let matcher = super::patterns::build_reduce_collapse_matcher();
    reduce_collapse_with(src, ranges, matcher)
}
/// Like `reduce_collapse`, but uses the load-aware reduce-collapse matcher.
///
/// Returns `None` if any range cannot be eliminated.
pub fn reduce_load_collapse(src: &Arc<UOp>, ranges: &[Arc<UOp>]) -> Option<Arc<UOp>> {
    let matcher = super::patterns::build_reduce_load_collapse_matcher();
    reduce_collapse_with(src, ranges, matcher)
}
/// Casts `value` to the element type of `target_dtype`.
///
/// For vector targets, the scalar cast is broadcast into a `Vectorize` of
/// `count()` lanes. Returns `None` for targets that are neither scalar nor
/// vector.
pub(crate) fn cast_to_dtype(value: &Arc<UOp>, target_dtype: &morok_dtype::DType) -> Option<Arc<UOp>> {
    use morok_dtype::DType;
    let element = match target_dtype {
        DType::Scalar(s) => *s,
        DType::Vector { scalar, .. } => *scalar,
        _ => return None,
    };
    let casted = value.cast(DType::Scalar(element));
    if !target_dtype.is_vector() {
        return Some(casted);
    }
    let lanes: SmallVec<[Arc<UOp>; 4]> = (0..target_dtype.count()).map(|_| casted.clone()).collect();
    Some(UOp::vectorize(lanes))
}
/// Tries to merge pairs of ranges closed by an `End` or `Reduce` node.
///
/// For `End`, only adjacent pairs `(i, i + 1)` are tried; for `Reduce`, every
/// ordered pair of distinct ranges is tried.
///
/// A pair `(r0, r1)` becomes one range `m` of size `end0 * end1`, substituting
/// `r0 -> m / end1` and `r1 -> m % end1`. The merge is attempted only when:
/// - both are `Range`s of the same axis type;
/// - every `Reduce` in the graph contains both of them or neither;
/// - neither has a constant size `<= 0`;
/// - if both sizes are constant, their product does not overflow `i64`.
///
/// After simplification, the merge is kept only if it does not increase the
/// number of div/mod ops (`count_divmod`).
///
/// Returns `Some` only when the final node differs from `u`.
pub fn simplify_merge_adjacent(u: &Arc<UOp>) -> Option<Arc<UOp>> {
    use crate::passes::linearize_index::count_divmod;

    // Matcher used to clean up after each candidate substitution; built once.
    static MERGE_SYM: std::sync::LazyLock<crate::TypedPatternMatcher> =
        std::sync::LazyLock::new(|| crate::symbolic::symbolic().clone() + pm_flatten_range().clone());

    let ended_ranges = match u.op() {
        Op::End { computation: _, ranges } => ranges.clone(),
        Op::Reduce { ranges, .. } => ranges.clone(),
        _ => return None,
    };
    if ended_ranges.len() < 2 {
        return None;
    }
    // Range lists of every REDUCE reachable from `u` (including `u` itself).
    let reduce_ranges: Vec<SmallVec<[Arc<UOp>; 4]>> = u
        .toposort()
        .iter()
        .filter_map(|dep| match dep.op() {
            Op::Reduce { ranges, .. } => Some(ranges.clone()),
            _ => None,
        })
        .collect();
    let n = ended_ranges.len();
    // END merges neighbours only; REDUCE tries every ordered pair.
    let pairs: Vec<(usize, usize)> = if matches!(u.op(), Op::End { .. }) {
        (0..n - 1).map(|i| (i, i + 1)).collect()
    } else {
        (0..n).flat_map(|i| (0..n).filter(move |&j| j != i).map(move |j| (i, j))).collect()
    };

    let mut current = Arc::clone(u);
    for (i0, i1) in pairs {
        let r0 = &ended_ranges[i0];
        let r1 = &ended_ranges[i1];
        let (r0_axis_type, r0_end) = match r0.op() {
            Op::Range { end, axis_type, .. } => (axis_type, end),
            _ => continue,
        };
        let (r1_axis_type, r1_end) = match r1.op() {
            Op::Range { end, axis_type, .. } => (axis_type, end),
            _ => continue,
        };
        if r0_axis_type != r1_axis_type {
            continue;
        }
        // Both ranges must belong to exactly the same set of reductions.
        let valid_reduce_scope = reduce_ranges.iter().all(|rngs| {
            let r0_in = rngs.iter().any(|rng| Arc::ptr_eq(rng, r0));
            let r1_in = rngs.iter().any(|rng| Arc::ptr_eq(rng, r1));
            r0_in == r1_in
        });
        if !valid_reduce_scope {
            continue;
        }
        if let Some(v) = const_uop_to_i64(r0_end)
            && v <= 0
        {
            continue;
        }
        if let Some(v) = const_uop_to_i64(r1_end)
            && v <= 0
        {
            continue;
        }
        if let (Some(s0), Some(s1)) = (const_uop_to_i64(r0_end), const_uop_to_i64(r1_end))
            && s0.checked_mul(s1).is_none()
        {
            continue;
        }

        let merged_size_uop = r0_end.mul(r1_end);
        let merged_range = r0.with_sources(vec![merged_size_uop]);
        let new_r0 = merged_range.idiv(r1_end);
        let new_r1 = merged_range.mod_(r1_end);
        #[allow(clippy::mutable_key_type)]
        let mut subs: HashMap<UOpKey, Arc<UOp>> = HashMap::new();
        subs.insert(UOpKey(r0.clone()), new_r0);
        subs.insert(UOpKey(r1.clone()), new_r1);
        let rewritten = current.substitute(&subs);
        let simplified = crate::rewrite::graph_rewrite(&*MERGE_SYM, rewritten, &mut ());
        // Identical node: nothing was merged or simplified, so there is nothing to compare.
        if Arc::ptr_eq(&simplified, &current) {
            continue;
        }
        if count_divmod(&simplified) <= count_divmod(&current) {
            current = simplified;
        }
    }
    // Report a change only for a genuinely different node. Returning `Some(u)`
    // unchanged could be mistaken for progress by a fixed-point rewrite driver.
    (!Arc::ptr_eq(&current, u)).then_some(current)
}
/// Cached matcher that runs [`simplify_merge_adjacent`] on `End` and `Reduce`
/// nodes with a non-empty range list.
pub fn pm_simplify_ranges() -> &'static crate::TypedPatternMatcher {
    crate::cached_patterns! {
        u @ End { computation: _, ranges } if !ranges.is_empty() => |u| simplify_merge_adjacent(u),
        u @ Reduce { src: _, ranges, reduce_op: _ } if !ranges.is_empty() => |u| simplify_merge_adjacent(u),
    }
}
/// Normalises the range list of an `End`, `Reduce`, or `Store` node.
///
/// The new range list is the set of `Range` nodes reachable from the current
/// range sources, in toposort order. For `End`, directly nested `End`s in the
/// computation are peeled off and their ranges absorbed, so a single `End`
/// closes them all.
///
/// Returns `None` for other ops, when no range sources or `Range` nodes are
/// found, or when the result would be identical.
pub fn flatten_range_impl(r: &Arc<UOp>) -> Option<Arc<UOp>> {
    // Number of leading non-range sources for each range-carrying op.
    let off = match r.op() {
        Op::Reduce { .. } | Op::End { .. } => 1,
        Op::Store { .. } => 2,
        _ => return None,
    };
    let is_end = matches!(r.op(), Op::End { .. });
    let sources = r.op().sources();
    let own_ranges: Vec<&Arc<UOp>> = sources.iter().skip(off).collect();
    let mut gathered: Vec<Arc<UOp>> = own_ranges.iter().copied().cloned().collect();

    // For END: absorb ranges of nested ENDs and keep the innermost computation.
    let innermost = is_end.then(|| {
        let mut computation = Arc::clone(&sources[0]);
        while matches!(computation.op(), Op::End { .. }) {
            gathered.extend(computation.op().sources().iter().skip(1).cloned());
            computation = Arc::clone(&computation.op().sources()[0]);
        }
        computation
    });
    if gathered.is_empty() {
        return None;
    }

    let flattened: Vec<Arc<UOp>> = UOp::sink(gathered)
        .toposort()
        .into_iter()
        .filter(|node| matches!(node.op(), Op::Range { .. }))
        .collect();
    if flattened.is_empty() {
        return None;
    }
    let unchanged = flattened.len() == own_ranges.len()
        && innermost.as_ref().is_none_or(|c| Arc::ptr_eq(c, &sources[0]))
        && flattened.iter().zip(&own_ranges).all(|(a, b)| Arc::ptr_eq(a, *b));
    if unchanged {
        return None;
    }

    let mut rebuilt: Vec<Arc<UOp>> = match innermost {
        Some(computation) => vec![computation],
        None => sources[..off].to_vec(),
    };
    rebuilt.extend(flattened);
    Some(r.with_sources(rebuilt))
}
/// Applies [`flatten_range_impl`] to every node reachable from `root`.
///
/// All replacements are collected first and then applied with a single
/// `substitute`.
#[allow(clippy::mutable_key_type)]
pub fn flatten_ranges(root: &Arc<UOp>) -> Arc<UOp> {
    let replacements: HashMap<UOpKey, Arc<UOp>> = root
        .toposort()
        .into_iter()
        .filter_map(|node| flatten_range_impl(&node).map(|flat| (UOpKey(node), flat)))
        .collect();
    root.substitute(&replacements)
}
/// Kind of access a kernel performs on a buffer, as recorded by `find_bufs`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum OpAccessType {
    /// The buffer is read through a `Load`.
    Load,
    /// The buffer is the target of a store (`store_buffer()` returns it).
    Store,
}
/// Resolves a buffer-like node to the node that stands for its buffer.
///
/// * `MSelect` → its `buffer` source.
/// * non-empty `MStack` → its first buffer.
/// * `After` → its `passthrough` source.
/// * anything else (including an empty `MStack`) → `uop` itself.
pub fn as_buf(uop: &Arc<UOp>) -> Arc<UOp> {
    let resolved: &Arc<UOp> = match uop.op() {
        Op::MSelect { buffer, .. } => buffer,
        Op::MStack { buffers } => buffers.first().unwrap_or(uop),
        Op::After { passthrough, .. } => passthrough,
        _ => uop,
    };
    Arc::clone(resolved)
}
/// Collects the buffers touched under `store` and how each one is accessed.
///
/// Walks `store.toposort_filtered(..)` without descending into `After` nodes.
/// Each `Load`'s buffer is recorded as [`OpAccessType::Load`], and each node's
/// `store_buffer()` (if any) as [`OpAccessType::Store`]. Buffers are normalized
/// through [`as_buf`] before being keyed.
///
/// # Panics
///
/// Panics if the same buffer is seen with both a load and a store access.
#[allow(clippy::mutable_key_type)] // UOpKey wraps an Arc; keys are compared by node identity.
pub fn find_bufs(store: &Arc<UOp>) -> HashMap<UOpKey, OpAccessType> {
    /// Records `access` for the buffer behind `buffer`, panicking if that
    /// buffer was already recorded with a different access kind.
    fn record(seen: &mut HashMap<UOpKey, OpAccessType>, buffer: &Arc<UOp>, access: OpAccessType) {
        let buf = as_buf(buffer);
        let key = UOpKey(buf.clone());
        if let Some(&previous) = seen.get(&key)
            && previous != access
        {
            panic!(
                "buffer accessed with conflicting ops: {:?} (existing: {:?}, new: {:?})",
                buf, previous, access
            );
        }
        seen.insert(key, access);
    }

    let mut accesses = HashMap::new();
    for node in store.toposort_filtered(|uop| !matches!(uop.op(), Op::After { .. })) {
        if let Op::Load { buffer, .. } = node.op() {
            record(&mut accesses, buffer, OpAccessType::Load);
        }
        if let Some(buffer) = node.store_buffer() {
            record(&mut accesses, buffer, OpAccessType::Store);
        }
    }
    accesses
}
/// Rewrites a disk-device `Bufferize` into a `Bufferize` over a `BufferView` of
/// the underlying buffer.
///
/// `compute` is walked towards its first source until a node is reached that has
/// an `Index` among its sources. The `BufferView` is built over that index's
/// `base()`:
/// * `size` is the product of the bufferize ranges' extents. A `Range` contributes
///   the vmax of its `end` (clamped at 0); any other range contributes 1.
/// * `offset` is the sum of the integer vmins of the index's indices, clamped at 0.
///
/// Returns `None` (no rewrite) when `bufferize` is not a `Bufferize` on a disk
/// device, when the walk meets an elementwise op (`Unary`/`Binary`/`Ternary`/
/// `Cast`) before finding an `Index`, or when the walk runs out of sources.
fn late_buffer_view(compute: &Arc<UOp>, bufferize: &Arc<UOp>) -> Option<Arc<UOp>> {
    // `CachedProperty` must be in scope for `VminVmaxProperty::get`.
    use morok_ir::uop::cached_property::CachedProperty;
    use morok_ir::uop::properties::VminVmaxProperty;

    let Op::Bufferize { opts, ranges, .. } = bufferize.op() else { return None };
    if !matches!(&opts.device, Some(d) if d.is_disk()) {
        return None;
    }

    // Clamp at 0: a negative vmax previously wrapped to a huge usize via `as`
    // and then overflowed the product.
    let size: usize = ranges
        .iter()
        .map(|r| match r.op() {
            Op::Range { end, .. } => match VminVmaxProperty::get(end) {
                (_, morok_ir::ConstValue::Int(v)) => (*v).max(0) as usize,
                _ => 1,
            },
            _ => 1,
        })
        .product();

    // Walk down first sources until some node has an `Index` among its sources.
    let mut x = compute.clone();
    while !x.op().sources().iter().any(|s| matches!(s.op(), Op::Index { .. })) {
        // Elementwise ops transform the data, so a raw byte view would be wrong.
        if matches!(x.op(), Op::Unary(..) | Op::Binary(..) | Op::Ternary(..) | Op::Cast { .. }) {
            return None;
        }
        // Covers BitCast/Contiguous as well as any other pass-through node.
        x = x.op().sources().first()?.clone();
    }

    let index = x.op().sources().iter().find(|s| matches!(s.op(), Op::Index { .. }))?.clone();
    let offset: usize = match index.op() {
        Op::Index { indices, .. } => {
            // Empty `indices` sums to 0, i.e. offset 0.
            let total: i64 = indices
                .iter()
                .filter_map(|idx| match VminVmaxProperty::get(idx) {
                    (morok_ir::ConstValue::Int(v), _) => Some(*v),
                    _ => None,
                })
                .sum();
            total.max(0) as usize
        }
        _ => 0,
    };

    let buffer_view = UOp::new(Op::BufferView { buffer: index.base(), size, offset }, compute.dtype());
    Some(UOp::bufferize(buffer_view, ranges.iter().cloned().collect::<Vec<_>>(), opts.clone()))
}
/// Builds the pattern matcher that lowers `Bufferize` nodes to buffers/stores.
///
/// Rules, in the order they are listed (the order of the two `Bufferize`
/// fallbacks relative to the flattening rule appears significant — verify
/// against the matcher's first-match semantics before reordering):
/// 1. Multi-range `Bufferize` → `flatten_bufferize`.
/// 2. `Index` / `After` / `End` whose buffer/passthrough/computation is a
///    movement op → push the movement op through.
/// 3. `Bufferize` of a `BitCast`/`Contiguous` → `late_buffer_view`.
/// 4. Any other `Bufferize` → `bufferize_to_store(buf, ctx, false)`.
///
/// Identical to `pm_add_buffers_local_patterns` except for the `false` passed to
/// `bufferize_to_store` (NOTE(review): presumably a "local" flag — confirm).
pub fn pm_add_buffers_patterns() -> crate::TypedPatternMatcher<super::kernel::KernelContext> {
crate::patterns! {
@context super::kernel::KernelContext;
// Collapse a multi-range bufferize first.
buf @ Bufferize { compute: _ } if matches!(buf.op(), Op::Bufferize { ranges, .. } if ranges.len() > 1)
=> |buf, _ctx| { flatten_bufferize(buf) },
// Index through a movement op: fold the movement into the indices.
Index { buffer: mop, indices, gate } if mop.op().is_movement()
=> |mop, indices, gate, _ctx| {
super::patterns::transform_movement_through_index(mop, indices, gate)
},
// After wrapping a movement op: push the movement op outward.
After { passthrough: mop, deps } if mop.op().is_movement()
=> |mop, deps, _ctx| {
push_movement_through_after(mop, deps)
},
// End of a movement op: end its first source instead.
End { computation: mop, ranges } if mop.op().is_movement()
=> |mop, ranges, _ctx| {
let src = &mop.op().sources()[0];
Some(src.end(ranges.clone()))
},
// Disk-device view of a BitCast/Contiguous (returns None otherwise).
buf @ Bufferize { compute }
if matches!(compute.op(), Op::BitCast { .. } | Op::Contiguous { .. })
=> |buf, compute, _ctx| late_buffer_view(compute, buf),
// Fallback: materialize as a store.
buf @ Bufferize { compute: _ } => |buf, ctx| {
bufferize_to_store(buf, ctx, false)
},
}
}
/// Builds the pattern matcher that lowers `Bufferize` nodes to buffers/stores,
/// variant that passes `true` as the third argument of `bufferize_to_store`.
///
/// Rules, in listed order:
/// 1. Multi-range `Bufferize` → `flatten_bufferize`.
/// 2. `Index` / `After` / `End` over a movement op → push the movement op through.
/// 3. `Bufferize` of a `BitCast`/`Contiguous` → `late_buffer_view`.
/// 4. Any other `Bufferize` → `bufferize_to_store(buf, ctx, true)`.
///
/// Otherwise identical to `pm_add_buffers_patterns` (NOTE(review): the `true`
/// is presumably a "local" flag, per this function's name — confirm).
pub fn pm_add_buffers_local_patterns() -> crate::TypedPatternMatcher<super::kernel::KernelContext> {
crate::patterns! {
@context super::kernel::KernelContext;
// Collapse a multi-range bufferize first.
buf @ Bufferize { compute: _ } if matches!(buf.op(), Op::Bufferize { ranges, .. } if ranges.len() > 1)
=> |buf, _ctx| { flatten_bufferize(buf) },
// Index through a movement op: fold the movement into the indices.
Index { buffer: mop, indices, gate } if mop.op().is_movement()
=> |mop, indices, gate, _ctx| {
super::patterns::transform_movement_through_index(mop, indices, gate)
},
// After wrapping a movement op: push the movement op outward.
After { passthrough: mop, deps } if mop.op().is_movement()
=> |mop, deps, _ctx| {
push_movement_through_after(mop, deps)
},
// End of a movement op: end its first source instead.
End { computation: mop, ranges } if mop.op().is_movement()
=> |mop, ranges, _ctx| {
let src = &mop.op().sources()[0];
Some(src.end(ranges.clone()))
},
// Disk-device view of a BitCast/Contiguous (returns None otherwise).
buf @ Bufferize { compute }
if matches!(compute.op(), Op::BitCast { .. } | Op::Contiguous { .. })
=> |buf, compute, _ctx| late_buffer_view(compute, buf),
// Fallback: materialize as a store.
buf @ Bufferize { compute: _ } => |buf, ctx| {
bufferize_to_store(buf, ctx, true)
},
}
}