use std::collections::HashMap;
use std::sync::Arc;
use morok_dtype::DType;
use morok_ir::{AxisId, AxisType, BinaryOp, ConstValue, Op, SInt, UOp, UOpKey};
use tracing::{debug, info_span, instrument, trace, warn};
use crate::argsort;
/// A node's ranges as stored in `IndexingContext::range_map`: `(input_ranges, output_ranges)`,
/// in the order passed to `IndexingContext::set_ranges`.
type UOpRanges = (Vec<Arc<UOp>>, Vec<Arc<UOp>>);
/// Mutable state shared by the rangeify passes: which nodes must be realized, the ranges
/// assigned to each node, and the counter used to mint fresh axis ids.
#[derive(Default)]
pub struct IndexingContext {
    /// Nodes selected for realization. `Some(axes)` lists the axis indices to realize.
    /// Any entry (even `None`) makes `should_realize` return `true`, while
    /// `get_realize_axes` treats `None` as "no axes". NOTE(review): no visible code in this
    /// file inserts `None`; confirm whether that state is used elsewhere.
    pub realize_map: HashMap<UOpKey, Option<Vec<usize>>>,
    /// Per-node `(input_ranges, output_ranges)` recorded by `set_ranges`.
    pub range_map: HashMap<UOpKey, UOpRanges>,
    /// Next index handed out as `AxisId::Unrenumbered` by the `new_range*` constructors.
    range_idx: usize,
}
impl IndexingContext {
pub fn new() -> Self {
Self::default()
}
pub fn new_range(&mut self, size: &SInt, axistype: AxisType) -> Arc<UOp> {
if let SInt::Symbolic(u) = size
&& matches!(u.op(), Op::Range { .. })
{
return Arc::clone(u);
}
if let SInt::Const(1) = size {
return UOp::index_const(0);
}
self.new_range_uncollapsed(size, axistype)
}
pub fn new_range_uncollapsed(&mut self, size: &SInt, axistype: AxisType) -> Arc<UOp> {
let axis_id = AxisId::Unrenumbered(self.range_idx);
self.range_idx += 1;
let size_uop = size.to_uop(morok_dtype::DType::Index);
UOp::range_axis(size_uop, axis_id, axistype)
}
pub fn new_range_from_uop(&mut self, end: &Arc<UOp>, axis_type: AxisType) -> Arc<UOp> {
let axis_id = AxisId::Unrenumbered(self.range_idx);
self.range_idx += 1;
UOp::range_axis(Arc::clone(end), axis_id, axis_type)
}
pub fn mark_realize_all(&mut self, uop: &Arc<UOp>) -> morok_ir::Result<()> {
if let Some(shape) = uop.shape()? {
let axes = (0..shape.len()).collect();
self.realize_map.insert(UOpKey(Arc::clone(uop)), Some(axes));
}
Ok(())
}
pub fn mark_realize(&mut self, uop: &Arc<UOp>, axes: Vec<usize>) {
self.realize_map.insert(UOpKey(Arc::clone(uop)), Some(axes));
}
pub fn should_realize(&self, uop: &Arc<UOp>) -> bool {
self.realize_map.contains_key(&UOpKey(Arc::clone(uop)))
}
pub fn get_realize_axes(&self, uop: &Arc<UOp>) -> Option<&Vec<usize>> {
self.realize_map.get(&UOpKey(Arc::clone(uop))).and_then(|opt| opt.as_ref())
}
#[allow(dead_code)]
pub fn realize_map_keys(&self) -> Vec<&UOpKey> {
self.realize_map.keys().collect()
}
pub fn set_ranges(&mut self, uop: &Arc<UOp>, input_ranges: Vec<Arc<UOp>>, output_ranges: Vec<Arc<UOp>>) {
self.range_map.insert(UOpKey(Arc::clone(uop)), (input_ranges, output_ranges));
}
pub fn get_ranges(&self, uop: &Arc<UOp>) -> Option<&UOpRanges> {
self.range_map.get(&UOpKey(Arc::clone(uop)))
}
pub fn range_counter(&self) -> usize {
self.range_idx
}
}
#[derive(Default)]
pub struct SimplifyCache {
cache: HashMap<u64, Arc<UOp>>,
}
impl SimplifyCache {
#[inline]
fn get_or_simplify(&mut self, input: &Arc<UOp>, f: impl FnOnce() -> Arc<UOp>) -> Arc<UOp> {
if let Some(cached) = self.cache.get(&input.id) {
return cached.clone();
}
let result = f();
self.cache.insert(input.id, result.clone());
result
}
}
/// Runs the rangeify pipeline on `sink` and returns the rewritten graph together with the
/// indexing context that was built along the way.
///
/// 1. `pm_generate_realize_map` fills `ctx.realize_map` (its rewritten graph is discarded;
///    only the context side effects are used).
/// 2. `assign_ranges` walks the nodes in reversed `toposort()` order and records ranges.
///    That function reads each consumer's ranges from `ctx`, so presumably consumers must
///    come first.
/// 3. The rangeify patterns rewrite the graph using the populated context.
///
/// # Errors
/// Propagates errors from range assignment (e.g. shape queries).
#[allow(clippy::mutable_key_type)]
#[instrument(skip(sink), fields(sink_id = sink.id))]
pub fn run_rangeify(sink: Arc<UOp>) -> morok_ir::Result<(Arc<UOp>, IndexingContext)> {
    let mut ctx = IndexingContext::new();

    // Pass 1: realize map (side effects on `ctx` only).
    crate::rewrite::graph_rewrite_bottom_up(pm_generate_realize_map(), sink.clone(), &mut ctx);

    // Pass 2: range assignment over the reversed topological order.
    let consumer_map = sink.get_consumer_map();
    let consumers_first: Vec<_> = sink.toposort().into_iter().rev().collect();
    let mut simplify_cache = SimplifyCache::default();
    assign_ranges(&consumers_first, &consumer_map, &mut ctx, &mut simplify_cache)?;

    // Pass 3: the actual rewrite.
    let matcher = super::patterns::apply_rangeify_patterns();
    let rewritten = crate::rewrite::graph_rewrite_bottom_up(&matcher, sink, &mut ctx);
    Ok((rewritten, ctx))
}
/// Pattern matcher for the first rangeify pass: records in `IndexingContext::realize_map`
/// which nodes must be realized.
///
/// Every rule returns `None` (no replacement), so the matcher is run for its side effects
/// on the context. Errors from `mark_realize_all` (propagated from `UOp::shape`) are
/// deliberately ignored with `.ok()`.
///
/// Rules as written:
/// - `Sink`, `MStack`, `MSelect`: realize every source that is not `is_always_contiguous`.
/// - `Store`, `BufferView`, `Contiguous`: realize the node itself.
/// - `Reduce` with any `Outer`-typed range: realize the node itself.
/// - `Copy`, `Assign`: realize the node itself and every non-always-contiguous source.
fn pm_generate_realize_map() -> &'static crate::TypedPatternMatcher<IndexingContext> {
    crate::cached_patterns! {
        @context IndexingContext;
        // Everything that reaches the SINK must be materialized unless it already is.
        x @ Sink { sources: _ } => |x, ctx| {
            for src in x.op().sources() {
                if !is_always_contiguous(&src) {
                    ctx.mark_realize_all(&src).ok();
                }
            }
            None
        },
        x @ Store { index: _, value: _ } => |x, ctx| { ctx.mark_realize_all(x).ok(); None },
        x @ BufferView { buffer: _ } => |x, ctx| { ctx.mark_realize_all(x).ok(); None },
        // Only reductions that carry an `Outer` range are realized here.
        x @ Reduce { src: _, ranges, reduce_op: _ } => |x, ctx| {
            if ranges.iter().any(|r| matches!(r.op(), Op::Range { axis_type, .. } if *axis_type == AxisType::Outer)) {
                ctx.mark_realize_all(x).ok();
            }
            None
        },
        // COPY realizes both itself and its (non-contiguous) inputs.
        x @ Copy { src: _ } => |x, ctx| {
            ctx.mark_realize_all(x).ok();
            for src in x.op().sources() {
                if !is_always_contiguous(&src) {
                    ctx.mark_realize_all(&src).ok();
                }
            }
            None
        },
        x @ Contiguous { src: _ } => |x, ctx| { ctx.mark_realize_all(x).ok(); None },
        // ASSIGN realizes both itself and its (non-contiguous) target/value.
        x @ Assign { target: _, value: _ } => |x, ctx| {
            ctx.mark_realize_all(x).ok();
            for src in x.op().sources() {
                if !is_always_contiguous(&src) {
                    ctx.mark_realize_all(&src).ok();
                }
            }
            None
        },
        // Multi-device stacking/selection: realize inputs, not the node itself.
        x @ MStack { buffers: _ } => |x, ctx| {
            for src in x.op().sources() {
                if !is_always_contiguous(&src) {
                    ctx.mark_realize_all(&src).ok();
                }
            }
            None
        },
        x @ MSelect { device_index: _ } => |x, ctx| {
            for src in x.op().sources() {
                if !is_always_contiguous(&src) {
                    ctx.mark_realize_all(&src).ok();
                }
            }
            None
        },
    }
}
/// True for op kinds that this pass never needs to realize explicitly. They are treated as
/// already contiguous, so the realize-map rules skip them when walking a node's sources.
pub(crate) fn is_always_contiguous(uop: &Arc<UOp>) -> bool {
    match uop.op() {
        Op::Contiguous { .. } | Op::Assign { .. } | Op::Copy { .. } => true,
        Op::Buffer { .. } | Op::BufferView { .. } => true,
        Op::Const(_) | Op::Bind { .. } | Op::Device(_) => true,
        Op::MSelect { .. } | Op::MStack { .. } => true,
        Op::Param { .. } | Op::DefineLocal(_) | Op::DefineReg { .. } => true,
        Op::Load { .. } | Op::Kernel { .. } => true,
        _ => false,
    }
}
/// Returns `true` when `uop` is provably the boolean constant `true`.
///
/// Recognizes a literal `Const(Bool(true))` and any `Or` tree in which at least one operand
/// is itself provably `true`, since `true | x == true` for every `x`. Anything else is
/// conservatively reported as `false`, so the caller keeps the validity mask.
fn is_const_true(uop: &Arc<UOp>) -> bool {
    match uop.op() {
        Op::Const(cv) => matches!(cv.0, ConstValue::Bool(true)),
        // One provably-true operand is enough. Requiring both sides missed e.g. `true | v`.
        Op::Binary(BinaryOp::Or, a, b) => is_const_true(a) || is_const_true(b),
        _ => false,
    }
}
/// Merges the range lists of `uop`'s consumers into a single list of output ranges for `uop`.
///
/// `consumer_rngs[c][d]` is consumer `c`'s range for dimension `d`. Entries beyond `uop`'s
/// rank are ignored. Returns an empty list when `uop` has no shape.
///
/// Merging is all-or-nothing across dimensions. Every dimension must have at least one
/// range, and within each dimension all consumers must agree on the index (`get_idx`,
/// ignoring validity) — this is `all_all_same`. When that holds:
/// - dimensions whose ranges are pointer-identical reuse that range;
/// - otherwise the first consumer's index is kept, the validities are OR-ed together, and
///   the result is `where(valid, idx, invalid)` simplified with the symbolic rules (or just
///   `idx` when the OR is trivially true).
///
/// Otherwise every dimension gets a fresh `Loop` range and all of them are recorded in
/// `ctx` for realization.
///
/// # Errors
/// Propagates errors from `UOp::shape`, `try_or_op` and `UOp::try_where`.
#[instrument(skip(uop, consumer_rngs, ctx), fields(uop_id = uop.id))]
pub(crate) fn merge_consumer_ranges(
    uop: &Arc<UOp>,
    consumer_rngs: &[Vec<Arc<UOp>>],
    ctx: &mut IndexingContext,
) -> morok_ir::Result<Vec<Arc<UOp>>> {
    let Some(shape) = uop.shape()? else {
        return Ok(Vec::new());
    };
    let num_dims = shape.len();
    // Transpose: all_rngs[d] = every consumer's range for dimension d.
    let mut all_rngs: Vec<Vec<Arc<UOp>>> = vec![Vec::new(); num_dims];
    for consumer_rng in consumer_rngs {
        for (dim_idx, range) in consumer_rng.iter().enumerate() {
            if dim_idx < num_dims {
                all_rngs[dim_idx].push(Arc::clone(range));
            }
        }
    }
    let mut out_rngs = Vec::new();
    let mut realize_axes = Vec::new();
    // Global agreement check: every dimension non-empty and index-compatible across consumers.
    let all_all_same = all_rngs.iter().all(|dim_ranges| {
        if dim_ranges.is_empty() {
            return false;
        }
        if dim_ranges.iter().skip(1).all(|r| Arc::ptr_eq(&dim_ranges[0], r)) {
            return true;
        }
        let indices: Vec<_> = dim_ranges.iter().map(|r| r.get_idx()).collect();
        all_ranges_same(&indices)
    });
    for (dim_idx, dim_ranges) in all_rngs.iter().enumerate() {
        // No consumer indexes this dimension: give it its own loop and realize it.
        if dim_ranges.is_empty() {
            out_rngs.push(ctx.new_range(&shape[dim_idx], AxisType::Loop));
            realize_axes.push(dim_idx);
            continue;
        }
        // Fast path: every consumer uses the very same range node.
        if dim_ranges.iter().skip(1).all(|r| Arc::ptr_eq(&dim_ranges[0], r)) && all_all_same {
            out_rngs.push(Arc::clone(&dim_ranges[0]));
            continue;
        }
        let indices: Vec<_> = dim_ranges.iter().map(|r| r.get_idx()).collect();
        let valids: Vec<_> = dim_ranges.iter().map(|r| r.get_valid()).collect();
        // Only used for the debug log below.
        let ranges_same = all_ranges_same(&indices);
        if all_all_same {
            debug!(dim_idx, ranges_same, all_all_same, "merge_consumer_ranges: merging dimension");
            let merged_idx = Arc::clone(&indices[0]);
            // OR of all consumer validities (the single-valid branch is equivalent to the fold).
            let merged_valid = if valids.len() == 1 {
                Arc::clone(&valids[0])
            } else {
                valids.iter().skip(1).try_fold(Arc::clone(&valids[0]), |acc, v| acc.try_or_op(v))?
            };
            let merged_range = if is_const_true(&merged_valid) {
                merged_idx
            } else {
                let raw = UOp::try_where(merged_valid, merged_idx, UOp::invalid_marker())?;
                crate::rewrite::graph_rewrite(crate::symbolic::patterns::symbolic(), raw, &mut ())
            };
            out_rngs.push(merged_range);
        } else {
            debug!(dim_idx, "merge_consumer_ranges: creating NEW Loop range (ranges not compatible)");
            out_rngs.push(ctx.new_range(&shape[dim_idx], AxisType::Loop));
            realize_axes.push(dim_idx);
        }
    }
    if !realize_axes.is_empty() {
        // NOTE(review): fires on every range conflict; `warn` may be noisy for an expected event.
        warn!(realize_axes = ?realize_axes, "range conflict detected - marking axes for realization");
        ctx.mark_realize(uop, realize_axes.clone());
    }
    Ok(out_rngs)
}
/// Second rangeify pass: assigns input/output ranges to the nodes of `reverse_topo` and
/// records them via `IndexingContext::set_ranges`.
///
/// Each node reads the *input* ranges its consumers already recorded in `ctx`, plus the
/// "ending" ranges they propagated. Consumers are therefore expected to appear before
/// their producers. NOTE(review): this relies on the caller's ordering; `run_rangeify`
/// passes `toposort()` reversed.
///
/// Per node:
/// 1. Skip `Device`/`Unique`, `Index`-dtype nodes, and `MStack`/`MSelect`.
/// 2. Inherit the concatenation of the consumers' ending ranges.
/// 3. Choose output ranges:
///    - fresh `Loop` ranges (all axes realized) if the node is in `realize_map`;
///    - the single consumer's ranges when there is one;
///    - `merge_consumer_ranges` when there are several.
///    Realized nodes without a shape, and nodes without ranged consumers, are skipped.
/// 4. If ending ranges reach a `ReduceAxis` or elementwise node, mark all of its output axes
///    for realization and replace them with fresh `Loop` ranges.
/// 5. Derive input ranges: through `apply_movement_op` for movement ops, or with fresh
///    `Reduce` ranges on the reduced axes of a `ReduceAxis`.
/// 6. For an `EXPAND` with a static shape, the ranges inside each output range that differs
///    (by pointer) from its input range are recorded as ending at this node.
///
/// # Errors
/// Propagates errors from `UOp::shape` and `merge_consumer_ranges`.
#[allow(clippy::mutable_key_type)]
#[instrument(skip_all)]
fn assign_ranges(
    reverse_topo: &[Arc<UOp>],
    consumer_map: &HashMap<UOpKey, Vec<Arc<UOp>>>,
    ctx: &mut IndexingContext,
    simplify_cache: &mut SimplifyCache,
) -> morok_ir::Result<()> {
    // Per-node list of ranges that end at or above it (seeded by static EXPANDs, step 6).
    let mut ending_ranges: HashMap<UOpKey, Vec<Arc<UOp>>> = HashMap::new();
    for x in reverse_topo {
        if matches!(x.op(), Op::Device(_) | Op::Unique(_)) {
            continue;
        }
        // Index-dtype nodes are skipped (presumably index arithmetic, not tensor values).
        if x.dtype().scalar() == Some(morok_dtype::ScalarDType::Index) {
            continue;
        }
        if matches!(x.op(), Op::MStack { .. } | Op::MSelect { .. }) {
            continue;
        }
        let _span = info_span!("assign_range", uop_id = x.id, op = x.op().as_ref()).entered();
        let consumers: Vec<_> = consumer_map.get(&UOpKey(x.clone())).cloned().unwrap_or_default();
        // Input ranges of consumers that have already been assigned ranges.
        let consumer_rngs: Vec<Vec<Arc<UOp>>> =
            consumers.iter().filter_map(|c| ctx.get_ranges(c).map(|(inp, _)| inp.clone())).collect();
        debug!(
            num_consumers = consumers.len(),
            consumer_rngs_len = consumer_rngs.len(),
            consumer_ids = ?consumers.iter().map(|c| c.id).collect::<Vec<_>>(),
            "Consumer info"
        );
        // Step 2: inherit ending ranges from all consumers (concatenated, not deduplicated).
        let mut inherited_ending: Vec<Arc<UOp>> = Vec::new();
        for consumer in &consumers {
            inherited_ending.extend(ending_ranges.get(&UOpKey(consumer.clone())).cloned().unwrap_or_default());
        }
        if !inherited_ending.is_empty() {
            debug!(
                node_id = x.id,
                inherited_count = inherited_ending.len(),
                consumer_ids = ?consumers.iter().map(|c| c.id).collect::<Vec<_>>(),
                "ending_ranges: node inherits from consumers"
            );
        }
        ending_ranges.insert(UOpKey(x.clone()), inherited_ending);
        // Step 3: output ranges.
        let mut out_rngs = if ctx.should_realize(x) {
            if let Some(shape) = x.shape()? {
                debug!(
                    node_id = x.id,
                    op = x.op().as_ref(),
                    dims = shape.len(),
                    "REALIZE via realize_map (fresh ranges)"
                );
                let rngs: Vec<_> = shape.iter().map(|s| ctx.new_range(s, AxisType::Loop)).collect();
                // Realization covers every axis and clears inherited ending ranges.
                let axes: Vec<usize> = (0..shape.len()).collect();
                ctx.realize_map.insert(UOpKey(x.clone()), Some(axes));
                ending_ranges.insert(UOpKey(x.clone()), Vec::new());
                rngs
            } else {
                continue;
            }
        } else if consumer_rngs.is_empty() {
            continue;
        } else if consumer_rngs.len() == 1 {
            consumer_rngs[0].clone()
        } else {
            merge_consumer_ranges(x, &consumer_rngs, ctx)?
        };
        debug!(should_realize = ctx.should_realize(x), out_rngs_len = out_rngs.len(), "output ranges computed");
        // Step 4: ending ranges force realization of reductions and elementwise ops.
        let ending = ending_ranges.get(&UOpKey(x.clone())).cloned().unwrap_or_default();
        if !ending.is_empty() {
            debug!(
                ending_count = ending.len(),
                triggers_realization = matches!(x.op(), Op::ReduceAxis { .. }) || is_elementwise_op(x),
                "Ending ranges detected (pre-in_rngs check)"
            );
        }
        // NOTE(review): currently an unfiltered copy of `ending`.
        let filtered_ending = ending.clone();
        if !filtered_ending.is_empty() && (matches!(x.op(), Op::ReduceAxis { .. }) || is_elementwise_op(x)) {
            if let Some(shape) = x.shape().ok().flatten() {
                let mut realize_axes: Vec<usize> = ctx.get_realize_axes(x).cloned().unwrap_or_default();
                // NOTE(review): this adds every output axis, so the "selective" realization
                // below realizes all axes of `x`.
                for (i, _r) in out_rngs.iter().enumerate() {
                    if realize_axes.contains(&i) {
                        continue;
                    }
                    realize_axes.push(i);
                }
                debug!(
                    node_id = x.id,
                    op = x.op().as_ref(),
                    ending_count = ending.len(),
                    realize_axes = ?realize_axes,
                    "SELECTIVE REALIZATION via ending_ranges"
                );
                ending_ranges.insert(UOpKey(x.clone()), Vec::new());
                if !realize_axes.is_empty() {
                    ctx.mark_realize(x, realize_axes.clone());
                    // Realized axes get fresh loops; the rest keep their consumer ranges.
                    out_rngs = out_rngs
                        .iter()
                        .enumerate()
                        .map(|(i, r)| {
                            if realize_axes.contains(&i) {
                                ctx.new_range(&shape[i], AxisType::Loop)
                            } else {
                                Arc::clone(r)
                            }
                        })
                        .collect();
                }
            } else {
                ending_ranges.insert(UOpKey(x.clone()), Vec::new());
            }
        }
        // Step 5: input ranges.
        let in_rngs = match x.op() {
            Op::Reshape { src, .. }
            | Op::Permute { src, .. }
            | Op::Expand { src, .. }
            | Op::Pad { src, .. }
            | Op::Shrink { src, .. }
            | Op::Flip { src, .. } => {
                if let Some(in_shape) = src.shape()? {
                    apply_movement_op(x.op(), in_shape, &out_rngs, simplify_cache)
                } else {
                    out_rngs.clone()
                }
            }
            Op::ReduceAxis { src, axes, .. } => {
                if let Some(in_shape) = src.shape()? {
                    // Trace-only diagnostics; no effect on the assigned ranges.
                    if tracing::enabled!(tracing::Level::TRACE) {
                        let out_shape = x.shape()?;
                        trace!(
                            uop.id = x.id,
                            reduce.axes = ?axes,
                            in_shape.len = in_shape.len(),
                            out_shape.len = ?out_shape.as_ref().map(|s| s.len()),
                            out_rngs.len = out_rngs.len(),
                            "ReduceAxis range assignment"
                        );
                        for (idx, rng) in out_rngs.iter().enumerate() {
                            match rng.op() {
                                Op::Binary(binop, a, b) => {
                                    trace!(
                                        range.index = idx,
                                        range.id = rng.id,
                                        op = "Binary",
                                        binary_op = ?binop,
                                        left.id = a.id,
                                        right.id = b.id,
                                        "ReduceAxis out_rngs entry"
                                    );
                                }
                                Op::Range { axis_id, axis_type, .. } => {
                                    trace!(
                                        range.index = idx,
                                        range.id = rng.id,
                                        op = "Range",
                                        axis.id = ?axis_id,
                                        axis.type_ = ?axis_type,
                                        "ReduceAxis out_rngs entry"
                                    );
                                }
                                _ => {
                                    trace!(
                                        range.index = idx,
                                        range.id = rng.id,
                                        op = ?std::mem::discriminant(rng.op()),
                                        "ReduceAxis out_rngs entry"
                                    );
                                }
                            }
                        }
                    }
                    // Reduced axes iterate over fresh Reduce ranges; kept axes reuse the
                    // output ranges (or get a fresh Loop if out_rngs is too short).
                    let mut rngs = Vec::with_capacity(in_shape.len());
                    for (i, s) in in_shape.iter().enumerate() {
                        if axes.contains(&i) {
                            rngs.push(ctx.new_range(s, AxisType::Reduce));
                        } else if i < out_rngs.len() {
                            rngs.push(Arc::clone(&out_rngs[i]));
                            trace!(dim.index = i, range.id = out_rngs[i].id, "ReduceAxis using existing out_rngs");
                        } else {
                            rngs.push(ctx.new_range(s, AxisType::Loop));
                        }
                    }
                    rngs
                } else {
                    out_rngs.clone()
                }
            }
            _ => out_rngs.clone(),
        };
        debug!(in_rngs_len = in_rngs.len(), "input ranges computed");
        // Step 6: a static EXPAND ends every range it broadcasts away.
        if let Op::Expand { new_shape, .. } = x.op() {
            let shape_is_static = extract_shape_from_uop(new_shape).iter().all(|s| match s {
                SInt::Const(_) | SInt::Infer => true,
                SInt::Symbolic(uop) => !matches!(uop.op(), Op::Range { .. }),
            });
            debug!(
                expand_id = x.id,
                shape_is_static = shape_is_static,
                in_rngs_len = in_rngs.len(),
                out_rngs_len = out_rngs.len(),
                in_rngs_ids = ?in_rngs.iter().map(|r| (r.id, format!("{:?}", std::mem::discriminant(r.op())))).collect::<Vec<_>>(),
                out_rngs_ids = ?out_rngs.iter().map(|r| (r.id, format!("{:?}", std::mem::discriminant(r.op())))).collect::<Vec<_>>(),
                "ending_ranges: EXPAND being processed"
            );
            if shape_is_static {
                let mut changed_ranges: Vec<Arc<UOp>> = Vec::new();
                for (inp, out) in in_rngs.iter().zip(out_rngs.iter()) {
                    if !Arc::ptr_eq(inp, out) {
                        changed_ranges.extend(collect_ranges_from_uop(out));
                    }
                }
                if !changed_ranges.is_empty() {
                    debug!(
                        expand_id = x.id,
                        changed_ranges_count = changed_ranges.len(),
                        changed_range_ids = ?changed_ranges.iter().map(|r| r.id).collect::<Vec<_>>(),
                        "ending_ranges: EXPAND marking ranges as ending"
                    );
                    let mut ending = ending_ranges.get(&UOpKey(x.clone())).cloned().unwrap_or_default();
                    ending.extend(changed_ranges);
                    ending_ranges.insert(UOpKey(x.clone()), ending);
                }
            }
        }
        ctx.set_ranges(x, in_rngs, out_rngs);
    }
    Ok(())
}
/// Maps the output ranges `rngs` of a movement op back to ranges over its input, whose
/// shape is `in_shape`.
///
/// - `Shrink`: offsets each range by its begin (unchanged when the begin is the constant 0).
/// - `Permute`: reorders by the inverse permutation `argsort(axes)`.
/// - `Flip`: a flipped axis becomes `(shape - 1) - rng`.
/// - `Expand`: broadcast axes (input size 1 → larger/symbolic output) become the constant 0.
///   Missing leading dimensions are padded with 0 ranges and size-1 input dims.
/// - `Pad`: `where(!(rng < begin) && rng < shape + begin, rng - begin, invalid)`. The
///   validity mask is simplified through a memoized symbolic pass.
/// - `Reshape`: identity when the shapes match element-wise. Otherwise the ranges are
///   linearized and re-split by `apply_reshape_core` under placeholder canonicalization.
///
/// # Panics
/// - On a non-movement op.
/// - When a UOp arithmetic helper returns an error.
/// - When a shape/pad/begin operand is not a constant or vectorize.
pub fn apply_movement_op(
    op: &Op,
    in_shape: &[SInt],
    rngs: &[Arc<UOp>],
    simplify_cache: &mut SimplifyCache,
) -> Vec<Arc<UOp>> {
    match op {
        Op::Shrink { begins, .. } => {
            let begin_uops = extract_shape_uops(begins);
            rngs.iter()
                .zip(begin_uops.iter())
                .map(|(rng, begin)| {
                    if is_const_zero(begin) {
                        Arc::clone(rng)
                    } else {
                        rng.try_add(begin).expect("SHRINK: try_add failed")
                    }
                })
                .collect()
        }
        Op::Permute { axes, .. } => {
            // Output dim j reads input dim axes[j], so input dim i takes rngs[argsort(axes)[i]].
            let inv_perm = argsort(axes);
            inv_perm.iter().map(|&i| Arc::clone(&rngs[i])).collect()
        }
        Op::Flip { axes: flips, .. } => rngs
            .iter()
            .zip(in_shape.iter())
            .zip(flips.iter())
            .map(|((rng, shape), &flip)| {
                if !flip {
                    Arc::clone(rng)
                } else {
                    let shape_uop = shape.to_uop(morok_dtype::DType::Index);
                    let shape_minus_1 = shape_uop.try_sub(&UOp::index_const(1)).unwrap();
                    shape_minus_1.try_sub(rng).unwrap()
                }
            })
            .collect(),
        Op::Expand { new_shape, .. } => {
            let new_shape_vals = extract_shape_from_uop(new_shape);
            // Left-pad ranges with 0 so they align with the (possibly higher-rank) new shape.
            let padded_rngs: Vec<Arc<UOp>> = if rngs.len() < new_shape_vals.len() {
                let padding = new_shape_vals.len() - rngs.len();
                let mut v = Vec::with_capacity(new_shape_vals.len());
                for _ in 0..padding {
                    v.push(UOp::index_const(0));
                }
                v.extend(rngs.iter().cloned());
                v
            } else {
                rngs.to_vec()
            };
            // Left-pad the input shape with size-1 dims in the same way.
            let padded_in_shape: Vec<SInt> = if in_shape.len() < new_shape_vals.len() {
                let padding = new_shape_vals.len() - in_shape.len();
                let mut v = Vec::with_capacity(new_shape_vals.len());
                for _ in 0..padding {
                    v.push(SInt::Const(1));
                }
                v.extend(in_shape.iter().cloned());
                v
            } else {
                in_shape.to_vec()
            };
            padded_rngs
                .iter()
                .zip(padded_in_shape.iter())
                .zip(new_shape_vals.iter())
                .map(|((rng, in_sh), out_sh)| {
                    let expanding = match (in_sh, out_sh) {
                        (SInt::Const(1), SInt::Const(n)) if *n > 1 => true,
                        (SInt::Const(1), SInt::Symbolic(_)) => true,
                        _ => false,
                    };
                    if expanding { UOp::index_const(0) } else { Arc::clone(rng) }
                })
                .collect()
        }
        Op::Pad { begin_pads, end_pads, .. } => {
            let begin_uops = extract_shape_uops(begin_pads);
            let end_uops = extract_shape_uops(end_pads);
            rngs.iter()
                .zip(in_shape.iter())
                .zip(begin_uops.iter().zip(end_uops.iter()))
                .map(|((rng, shape), (begin, end))| {
                    if is_const_zero(begin) && is_const_zero(end) {
                        return Arc::clone(rng);
                    }
                    // Valid iff begin <= rng < shape + begin.
                    let shape_plus_begin = shape.to_uop(morok_dtype::DType::Index).try_add(begin).unwrap();
                    let valid_low = rng.try_cmplt(begin).unwrap().not();
                    let valid_high = rng.try_cmplt(&shape_plus_begin).unwrap();
                    let valid = valid_low.try_and_op(&valid_high).unwrap();
                    static PAD_SIMPLIFY: std::sync::LazyLock<crate::TypedPatternMatcher> =
                        std::sync::LazyLock::new(|| {
                            crate::symbolic::patterns::symbolic()
                                + crate::symbolic::valid_simplification::pm_simplify_valid()
                        });
                    // Only the mask is simplified; the `where` is built outside the rewrite.
                    let valid = simplify_cache.get_or_simplify(&valid, || {
                        crate::rewrite::graph_rewrite(&*PAD_SIMPLIFY, valid.clone(), &mut ())
                    });
                    let adjusted_rng = rng.try_sub(begin).unwrap();
                    UOp::try_where(valid, adjusted_rng, UOp::invalid_marker()).unwrap()
                })
                .collect()
        }
        Op::Reshape { new_shape, .. } => {
            let new_shape_vals = extract_shape_from_uop(new_shape);
            // Fast path: a reshape to an identical shape leaves ranges untouched.
            if in_shape.len() == new_shape_vals.len() {
                let mut is_same_shape = true;
                for (in_dim, out_dim) in in_shape.iter().zip(new_shape_vals.iter()) {
                    match (in_dim, out_dim) {
                        (SInt::Const(a), SInt::Const(b)) if a == b => continue,
                        (SInt::Symbolic(a), SInt::Symbolic(b)) if a.id == b.id => continue,
                        _ => {
                            is_same_shape = false;
                            break;
                        }
                    }
                }
                if is_same_shape {
                    return rngs.to_vec();
                }
            }
            with_placeholder_canonicalization(rngs, |canonical| {
                apply_reshape_core(in_shape, &new_shape_vals, canonical, simplify_cache)
            })
        }
        _ => panic!("apply_movement_op called with non-movement op: {:?}", op),
    }
}
/// Re-expresses ranges over a reshape's output (`out_shape`) as ranges over its input
/// (`in_shape`).
///
/// The output coordinates are linearized row-major: the last dimension varies fastest,
/// because each range is multiplied by the product of the dims after it. The flat index is
/// then split back into `in_shape` coordinates with `%`/`/`, and the whole result is
/// simplified in one memoized rewrite. If there are fewer ranges than output dims, they are
/// left-padded with the constant 0.
///
/// # Panics
/// When any UOp arithmetic helper returns an error.
fn apply_reshape_core(
    in_shape: &[SInt],
    out_shape: &[SInt],
    rngs: &[Arc<UOp>],
    simplify_cache: &mut SimplifyCache,
) -> Vec<Arc<UOp>> {
    // NOTE(review): the rest of this file calls `crate::rewrite::graph_rewrite`; confirm
    // that `morok_ir::rewrite::graph_rewrite` is intentionally used here.
    use morok_ir::rewrite::graph_rewrite;
    let padded_rngs: Vec<Arc<UOp>> = if rngs.len() < out_shape.len() {
        let padding = out_shape.len() - rngs.len();
        let mut v = Vec::with_capacity(out_shape.len());
        for _ in 0..padding {
            v.push(UOp::index_const(0));
        }
        v.extend(rngs.iter().cloned());
        v
    } else {
        rngs.to_vec()
    };
    // Linearize: sum of rng_i * (product of output dims after i).
    let mut acc = UOp::index_const(1);
    let mut axes_in = Vec::new();
    for (shape_dim, rng) in out_shape.iter().zip(padded_rngs.iter()).rev() {
        axes_in.push(acc.try_mul(rng).unwrap());
        let dim_uop = shape_dim.to_uop(morok_dtype::DType::Index);
        acc = acc.try_mul(&dim_uop).unwrap();
    }
    let combined = axes_in.into_iter().reduce(|a, b| a.try_add(&b).unwrap()).unwrap_or_else(|| UOp::index_const(0));
    // Delinearize over the input shape, innermost dimension first, then restore order.
    let mut axes_out = Vec::new();
    let mut remaining = combined;
    for shape_dim in in_shape.iter().rev() {
        let dim_uop = shape_dim.to_uop(morok_dtype::DType::Index);
        axes_out.push(remaining.try_mod(&dim_uop).unwrap());
        remaining = remaining.try_div(&dim_uop).unwrap();
    }
    axes_out.reverse();
    static RESHAPE_SIMPLIFY: std::sync::LazyLock<crate::TypedPatternMatcher> = std::sync::LazyLock::new(|| {
        crate::symbolic::patterns::symbolic()
            + crate::symbolic::valid_simplification::pm_simplify_valid()
            + crate::symbolic::valid_simplification::pm_drop_and_clauses()
    });
    // Simplify all output axes together under a SINK, cached by the SINK's id.
    let sink = UOp::sink(axes_out);
    let simplified = simplify_cache.get_or_simplify(&sink, || graph_rewrite(&*RESHAPE_SIMPLIFY, sink.clone(), &mut ()));
    match simplified.op() {
        Op::Sink { sources } => sources.iter().cloned().collect(),
        _ => vec![simplified],
    }
}
/// Public entry point for reshape range mapping outside of `assign_ranges`.
///
/// Uses a throwaway `SimplifyCache` local to this call and always goes through placeholder
/// canonicalization. There is no same-shape fast path here.
pub fn apply_reshape_ranges(in_shape: &[SInt], out_shape: &[SInt], rngs: &[Arc<UOp>]) -> Vec<Arc<UOp>> {
    let mut scratch_cache = SimplifyCache::default();
    with_placeholder_canonicalization(rngs, |canonical| {
        apply_reshape_core(in_shape, out_shape, canonical, &mut scratch_cache)
    })
}
fn with_placeholder_canonicalization(rngs: &[Arc<UOp>], f: impl FnOnce(&[Arc<UOp>]) -> Vec<Arc<UOp>>) -> Vec<Arc<UOp>> {
let sink = UOp::sink(rngs.to_vec());
let ranges_in_expr: Vec<Arc<UOp>> = sink.ranges().clone();
#[allow(clippy::mutable_key_type)]
let mut sub_map: HashMap<UOpKey, Arc<UOp>> = HashMap::new();
#[allow(clippy::mutable_key_type)]
let mut reverse_map: HashMap<UOpKey, Arc<UOp>> = HashMap::new();
for (i, r) in ranges_in_expr.iter().enumerate() {
let Op::Range { end, .. } = r.op() else { continue };
let placeholder = UOp::range_axis(end.clone(), AxisId::Renumbered(i), AxisType::Placeholder);
sub_map.insert(UOpKey(r.clone()), placeholder.clone());
reverse_map.insert(UOpKey(placeholder), r.clone());
}
if sub_map.is_empty() {
return f(rngs);
}
let canonical_sink = sink.substitute(&sub_map);
let canonical_rngs: Vec<Arc<UOp>> = match canonical_sink.op() {
Op::Sink { sources } => sources.iter().cloned().collect(),
_ => vec![canonical_sink],
};
let result = f(&canonical_rngs);
let result_sink = UOp::sink(result);
let restored = result_sink.substitute(&reverse_map);
let output: Vec<Arc<UOp>> = match restored.op() {
Op::Sink { sources } => sources.iter().cloned().collect(),
_ => vec![restored],
};
debug_assert!(
!output.iter().any(|r| {
UOp::sink(vec![r.clone()])
.ranges()
.iter()
.any(|rng| matches!(rng.op(), Op::Range { axis_type: AxisType::Placeholder, .. }))
}),
"Placeholder-typed ranges leaked into output"
);
output
}
/// True when `uop` is an integer constant equal to zero, signed (`Int(0)`) or
/// unsigned (`UInt(0)`).
///
/// Used to skip no-op SHRINK offsets and PAD amounts. Unsigned zeros are accepted as well,
/// since this file treats `UInt` as a valid shape/offset representation elsewhere.
fn is_const_zero(uop: &Arc<UOp>) -> bool {
    matches!(uop.op(), Op::Const(cv) if matches!(cv.0, ConstValue::Int(0) | ConstValue::UInt(0)))
}
/// Splits a shape-like operand into one UOp per dimension.
///
/// A `Vectorize` yields its elements, and a single `Const` yields itself. NOTE(review): a
/// `VConst` operand is not accepted here, although `extract_shape_from_uop` accepts it.
///
/// # Panics
/// On any other op.
fn extract_shape_uops(uop: &Arc<UOp>) -> Vec<Arc<UOp>> {
    if let Op::Vectorize { elements } = uop.op() {
        return elements.to_vec();
    }
    if let Op::Const(_) = uop.op() {
        return vec![Arc::clone(uop)];
    }
    panic!("Expected vectorize or constant for shape uops, got {:?}", uop.op())
}
/// Decodes a shape operand into `SInt` dimensions.
///
/// - `Vectorize`: integer constants (`Int` or `UInt`) become `SInt::Const`; any other
///   element is kept as `SInt::Symbolic`.
/// - `Const`: a single `Int`/`UInt` dimension.
/// - `VConst`: every value must be `Int`/`UInt`. An empty `VConst` is a rank-0 shape.
///
/// `UInt` constants are accepted in every form. Previously only `VConst` accepted them:
/// a `Const` `UInt` panicked and a `Vectorize` `UInt` element was demoted to Symbolic.
///
/// NOTE(review): negative `Int` values are converted with `as usize` and therefore wrap;
/// confirm shapes can never carry negative constants here.
///
/// # Panics
/// On a non-integer `Const`/`VConst` value, or on any other op kind.
fn extract_shape_from_uop(uop: &Arc<UOp>) -> Vec<SInt> {
    match uop.op() {
        Op::Vectorize { elements } => elements
            .iter()
            .map(|elem| match elem.op() {
                Op::Const(cv) => match cv.0 {
                    ConstValue::Int(n) => SInt::Const(n as usize),
                    ConstValue::UInt(n) => SInt::Const(n as usize),
                    _ => SInt::Symbolic(Arc::clone(elem)),
                },
                _ => SInt::Symbolic(Arc::clone(elem)),
            })
            .collect(),
        Op::Const(cv) => match cv.0 {
            ConstValue::Int(n) => vec![SInt::Const(n as usize)],
            ConstValue::UInt(n) => vec![SInt::Const(n as usize)],
            _ => panic!("Expected int constant for shape"),
        },
        Op::VConst { values } if values.is_empty() => vec![],
        Op::VConst { values } => values
            .iter()
            .map(|cv| match cv {
                ConstValue::Int(n) => SInt::Const(*n as usize),
                ConstValue::UInt(n) => SInt::Const(*n as usize),
                _ => panic!("Expected int/uint constant in VConst shape"),
            })
            .collect(),
        _ => panic!("Expected vectorize or constant for shape, got {:?}", uop.op()),
    }
}
/// True when both lists have the same length and hold the very same nodes (pointer
/// identity, not structural equality) at every position.
pub fn ranges_equal(ranges1: &[Arc<UOp>], ranges2: &[Arc<UOp>]) -> bool {
    if ranges1.len() != ranges2.len() {
        return false;
    }
    ranges1.iter().zip(ranges2.iter()).all(|(lhs, rhs)| Arc::ptr_eq(lhs, rhs))
}
/// Two index expressions are compatible when they are the same node or structurally equal.
fn ranges_compatible(a: &Arc<UOp>, b: &Arc<UOp>) -> bool {
    if Arc::ptr_eq(a, b) {
        true
    } else {
        uop_equal(a, b)
    }
}
/// True when every range's index (`get_idx`, i.e. ignoring validity) is compatible with the
/// first one's. An empty slice counts as "all the same".
pub fn all_ranges_same(ranges: &[Arc<UOp>]) -> bool {
    let Some((head, rest)) = ranges.split_first() else {
        return true;
    };
    let reference = head.get_idx();
    rest.iter().all(|candidate| ranges_compatible(&reference, &candidate.get_idx()))
}
/// Structural equality of two UOp graphs.
///
/// Two nodes are equal when they are the same allocation, or when they have the same `Op`
/// variant and the same dtype, and additionally:
/// - `Const`: equal constant values;
/// - `VConst`: equal value lists (it has no sources, so the payload is all there is);
/// - `Range`: equal axis id and axis type, and recursively equal `end`;
/// - `Binary`: the same `BinaryOp` kind and pairwise-equal sources;
/// - anything else: pairwise-equal sources.
///
/// The `BinaryOp` kind must be compared explicitly, because the `Op` discriminant alone
/// does not distinguish `r + 2` from `r * 2`. Treating those as equal would let
/// `merge_consumer_ranges` merge incompatible consumer indices.
///
/// NOTE(review): other variants with a non-source payload (e.g. the op kind of `Unary` /
/// `Ternary`) are still compared by sources only; confirm whether they can appear in index
/// expressions. The recursion is not memoized, so heavily shared DAGs may be compared
/// repeatedly.
pub fn uop_equal(a: &Arc<UOp>, b: &Arc<UOp>) -> bool {
    if Arc::ptr_eq(a, b) {
        return true;
    }
    if std::mem::discriminant(a.op()) != std::mem::discriminant(b.op()) {
        return false;
    }
    if a.dtype() != b.dtype() {
        return false;
    }
    match (a.op(), b.op()) {
        (Op::Const(cv_a), Op::Const(cv_b)) => return cv_a.0 == cv_b.0,
        (Op::VConst { values: values_a }, Op::VConst { values: values_b }) => return values_a == values_b,
        (
            Op::Range { end: end_a, axis_id: id_a, axis_type: type_a, .. },
            Op::Range { end: end_b, axis_id: id_b, axis_type: type_b, .. },
        ) => return id_a == id_b && type_a == type_b && uop_equal(end_a, end_b),
        (Op::Binary(op_a, ..), Op::Binary(op_b, ..)) => {
            // Same variant does not imply the same operator (Add vs Mul, ...).
            if std::mem::discriminant(op_a) != std::mem::discriminant(op_b) {
                return false;
            }
        }
        _ => {}
    }
    let a_srcs = a.op().sources();
    let b_srcs = b.op().sources();
    if a_srcs.len() != b_srcs.len() {
        return false;
    }
    a_srcs.iter().zip(b_srcs.iter()).all(|(sa, sb)| uop_equal(sa, sb))
}
/// True for a `Range` whose maximum value (`vmax`) is at most 0. Returns false for non-range
/// nodes and for non-integer bounds.
pub fn is_dead_axis(range: &Arc<UOp>) -> bool {
    let Op::Range { .. } = range.op() else {
        return false;
    };
    let upper = range.vmax();
    matches!(upper, ConstValue::Int(m) if *m <= 0) || matches!(upper, ConstValue::UInt(0))
}
/// True when no `Range` node is among `uop`'s in-scope ranges.
#[allow(clippy::mutable_key_type)]
pub fn no_range(uop: &Arc<UOp>) -> bool {
    uop.in_scope_ranges().iter().all(|key| !matches!(key.0.op(), Op::Range { .. }))
}
/// Returns the constant trip count of a `Range` as `i64`.
///
/// Returns `None` when `range` is not a `Range`, when its `end` is not an integer
/// constant, or when an unsigned end does not fit in `i64`. The old `as i64` cast silently
/// wrapped such values to negative counts.
pub fn range_size_as_i64(range: &Arc<UOp>) -> Option<i64> {
    let Op::Range { end, .. } = range.op() else {
        return None;
    };
    let Op::Const(cv) = end.op() else {
        return None;
    };
    match cv.0 {
        ConstValue::Int(n) => Some(n),
        ConstValue::UInt(n) => i64::try_from(n).ok(),
        _ => None,
    }
}
/// True when `value` is an identity element for `op`, i.e. `x op value == x`.
///
/// `is_right` says whether `value` is the right-hand operand. It matters only for the
/// non-commutative ops (`Sub`, `Idiv`, `Fdiv`), where the identity holds only on the right.
///
/// Covers signed and unsigned integers, floats, and booleans:
/// - `+0`, `-0` (right), `*1`, `/1` (right);
/// - `| 0`, `^ 0`, `| false`, `^ false`;
/// - `& -1`, `& true`.
///
/// Note: float `+0.0` is accepted for `Add` even though `-0.0 + 0.0 == +0.0`, i.e. signed
/// zeros are treated as interchangeable. Unsigned all-ones is not listed for `And` because
/// it depends on the bit width.
pub fn is_identity_value(value: &ConstValue, op: &BinaryOp, is_right: bool) -> bool {
    match (op, value) {
        (BinaryOp::Add, ConstValue::Int(0) | ConstValue::UInt(0)) => true,
        (BinaryOp::Add, ConstValue::Float(f)) if *f == 0.0 => true,
        (BinaryOp::Sub, ConstValue::Int(0) | ConstValue::UInt(0)) if is_right => true,
        (BinaryOp::Sub, ConstValue::Float(f)) if is_right && *f == 0.0 => true,
        (BinaryOp::Mul, ConstValue::Int(1) | ConstValue::UInt(1)) => true,
        (BinaryOp::Mul, ConstValue::Float(f)) if *f == 1.0 => true,
        (BinaryOp::Idiv, ConstValue::Int(1) | ConstValue::UInt(1)) if is_right => true,
        (BinaryOp::Fdiv, ConstValue::Float(f)) if is_right && *f == 1.0 => true,
        (BinaryOp::Or | BinaryOp::Xor, ConstValue::Int(0) | ConstValue::UInt(0) | ConstValue::Bool(false)) => true,
        (BinaryOp::And, ConstValue::Int(-1) | ConstValue::Bool(true)) => true,
        _ => false,
    }
}
/// True when `value` is an absorbing ("zero") element for `op`, i.e. `x op value` is zero:
/// `x * 0`, `x * 0.0`, `x & 0`.
///
/// NOTE(review): `x * 0.0` is NaN for NaN/infinite `x`; callers relying on this must accept
/// that float approximation.
pub fn is_zero_value(value: &ConstValue, op: &BinaryOp) -> bool {
    match op {
        BinaryOp::Mul => match value {
            ConstValue::Int(0) => true,
            ConstValue::Float(f) => *f == 0.0,
            _ => false,
        },
        BinaryOp::And => matches!(value, ConstValue::Int(0)),
        _ => false,
    }
}
/// The constant value carried by a `Const` node, or `None` for any other op.
pub fn get_const_value(uop: &Arc<UOp>) -> Option<ConstValue> {
    if let Op::Const(cv) = uop.op() { Some(cv.0) } else { None }
}
/// True when `uop` is a `Const` node whose value equals `value`.
pub fn is_const(uop: &Arc<UOp>, value: &ConstValue) -> bool {
    get_const_value(uop).is_some_and(|found| found == *value)
}
/// True when `uop` has a known shape containing a constant 0 dimension. Shape errors and
/// unknown shapes count as "not zero-sized".
pub fn is_zero_size(uop: &Arc<UOp>) -> bool {
    match uop.shape() {
        Ok(Some(shape)) => shape.iter().any(|dim| matches!(dim, SInt::Const(0))),
        _ => false,
    }
}
/// True for the `Void` dtype.
pub fn is_void(dtype: &DType) -> bool {
    &DType::Void == dtype
}
/// The operator of a `Binary` node, or `None` for any other op.
pub fn get_binary_op(uop: &Arc<UOp>) -> Option<BinaryOp> {
    if let Op::Binary(op, ..) = uop.op() { Some(*op) } else { None }
}
/// True for a `Bufferize` node whose options target the `Local` address space.
pub fn is_local_bufferize(uop: &Arc<UOp>) -> bool {
    matches!(uop.op(), Op::Bufferize { opts, .. } if opts.addrspace == morok_ir::AddrSpace::Local)
}
#[allow(clippy::mutable_key_type)]
fn collect_ranges_from_uop(uop: &Arc<UOp>) -> Vec<Arc<UOp>> {
use std::collections::HashSet;
let mut ranges = Vec::new();
let mut seen = HashSet::new();
for node in uop.toposort() {
if matches!(node.op(), Op::Range { .. }) {
let key = UOpKey(Arc::clone(&node));
if seen.insert(key) {
ranges.push(node);
}
}
}
ranges
}
/// True for element-wise ops: unary/binary/ternary arithmetic and dtype casts.
fn is_elementwise_op(uop: &Arc<UOp>) -> bool {
    match uop.op() {
        Op::Unary(..) | Op::Binary(..) | Op::Ternary(..) => true,
        Op::Cast { .. } | Op::BitCast { .. } => true,
        _ => false,
    }
}