use std::collections::HashSet;
use std::sync::Arc;
use crate::argsort;
use morok_device::DeviceSpec;
use morok_dtype::{AddrSpace, DType, ScalarDType};
use morok_ir::{AxisId, AxisType, BinaryOp, BufferizeOpts, ConstValue, Op, ReduceOp, UOp, UOpKey, UnaryOp};
use smallvec::SmallVec;
use tracing::trace;
use crate::TypedPatternMatcher;
use crate::rangeify::transforms::{cast_to_dtype, get_range_size, partition_reduce_ranges};
pub use super::indexing::{is_dead_axis, ranges_equal};
use super::indexing::IndexingContext;
use super::kernel::{KernelContext, LocalAddBufferContext};
use super::kernel::{PcontigConfig, SplitReduceOpConfig, split_reduceop};
use super::transforms::transform_sources_with_bufferize;
/// Returns `true` when `shape` is a `Vectorize` with no elements, i.e. the rank-0 (scalar) shape.
fn is_scalar_shape(shape: &Arc<UOp>) -> bool {
    matches!(shape.op(), Op::Vectorize { elements } if elements.is_empty())
}
/// Returns `true` when `uop` has a known shape containing a dimension that is the constant `0`.
///
/// An unknown shape (error or `None`) and symbolic dimensions count as non-zero.
fn has_zero_size(uop: &Arc<UOp>) -> bool {
    let Ok(Some(dims)) = uop.shape() else {
        return false;
    };
    dims.iter().any(|dim| dim.as_const() == Some(0))
}
/// Returns `true` for ops that are recomputed inline instead of being materialized into a buffer.
///
/// These are leaf/definition ops, element-wise ALU ops, casts, and vector/pointer plumbing.
pub fn is_cheap_to_inline(op: &Op) -> bool {
    match op {
        // Leaves and definitions.
        Op::Const(_)
        | Op::Unique(_)
        | Op::Device(_)
        | Op::Noop
        | Op::DefineVar { .. }
        | Op::DefineReg { .. }
        | Op::VConst { .. } => true,
        // Element-wise arithmetic.
        Op::Unary(..) | Op::Binary(..) | Op::Ternary(..) => true,
        // Type conversion and vector / pointer plumbing.
        Op::Cast { .. } | Op::BitCast { .. } | Op::Gep { .. } | Op::Vectorize { .. } | Op::PointerIndex { .. } => {
            true
        }
        _ => false,
    }
}
/// Returns `true` when `compute` is a unary op with at least one `Reduce`-typed range in scope.
fn unary_in_reduce_context(compute: &Arc<UOp>) -> bool {
    let Op::Unary(..) = compute.op() else {
        return false;
    };
    compute
        .in_scope_ranges()
        .iter()
        .any(|key| matches!(key.0.op(), Op::Range { axis_type: AxisType::Reduce, .. }))
}
/// Hook consulted by `buffer_removal` for a unary compute that sits under a reduce range.
///
/// This is currently a stub and never produces a replacement.
fn block_reduce_unary_inline(_buf: &Arc<UOp>) -> Option<Arc<UOp>> {
    // Deliberately no rewrite yet.
    Option::None
}
/// Returns `true` for ops that must always execute: `Contiguous`, `Copy`, and `Assign`.
///
/// Buffers produced by these ops are never removed or restructured by the passes in this file.
pub fn is_always_run_op(op: &Op) -> bool {
    match op {
        Op::Contiguous { .. } | Op::Copy { .. } | Op::Assign { .. } => true,
        _ => false,
    }
}
/// Returns `true` when `uop` is a binary or ternary ALU op.
///
/// Unary ops are deliberately not included.
pub fn is_elementwise(uop: &Arc<UOp>) -> bool {
    match uop.op() {
        Op::Binary(..) | Op::Ternary(..) => true,
        _ => false,
    }
}
/// Early cleanup rewrites applied before rangeify.
///
/// Arms, in order:
/// - `Detach(x)` and `ContiguousBackward(x)` are replaced by `x`.
/// - A `Reshape` is replaced by its source when the target shape is scalar (an empty
///   `Vectorize`) or when the source is a `Reduce`.
/// - A `ReduceAxis` whose input has a constant-zero dimension, but whose own shape does not, is
///   replaced by `reduce_identity` for its `reduce_op` and dtype.
/// - Any other non-`Sink` node with a constant-zero dimension becomes `x.const_like(0)`.
///
/// NOTE(review): why a `Reshape` over a `Reduce` can simply be dropped is not visible here.
/// Confirm that shapes are guaranteed compatible at this stage.
pub fn early_rewrites() -> TypedPatternMatcher {
crate::patterns! {
Detach(x) ~> |x| x.clone(),
ContiguousBackward(x) ~> |x| x.clone(),
// Drop reshapes to the scalar shape and reshapes sitting directly on a `Reduce`.
Reshape { src, new_shape } => |src, new_shape| {
if is_scalar_shape(new_shape) {
return Some(src.clone());
}
if matches!(src.op(), Op::Reduce { .. }) {
return Some(src.clone());
}
None
},
// Reducing an empty input into a non-empty output yields the reduction identity.
reduce @ ReduceAxis { src: x } if has_zero_size(x) && !has_zero_size(reduce) => {
let Op::ReduceAxis { reduce_op, .. } = reduce.op() else { return None };
Some(crate::symbolic::dce::reduce_identity(*reduce_op, reduce.dtype()))
},
// Any remaining zero-sized node (except the root Sink) folds to a zero constant.
x if !matches!(x.op(), Op::Sink { .. }) && has_zero_size(x) => {
Some(x.const_like(0))
}
}
}
/// Side table for `replace_contiguous`.
///
/// It maps a base node to the `Contiguous` node, with its movement chain undone, that should
/// replace the base in element-wise consumers.
pub type ReplaceContiguousCtx = std::collections::HashMap<UOpKey, Arc<UOp>>;
/// Two-part rewrite that shares a context table:
/// 1. On `Contiguous` of a movement op, `found_contiguous` records `base -> adjusted contiguous`
///    in the table. That arm itself never produces a rewrite.
/// 2. Unary, binary, and ternary ops whose direct sources are keys in the table get those
///    sources swapped for the recorded replacement. The op is rebuilt only if at least one
///    source changed.
pub fn replace_contiguous() -> TypedPatternMatcher<ReplaceContiguousCtx> {
crate::patterns! {
@context ReplaceContiguousCtx;
contig @ Contiguous { src, .. } if src.op().is_movement() => |contig, src, ctx| {
found_contiguous(ctx, contig, src)
},
// Substitute recorded replacements into element-wise consumers.
x if matches!(x.op(), Op::Unary(..) | Op::Binary(..) | Op::Ternary(..)) => |x, ctx| {
if ctx.is_empty() { return None; }
let sources = x.op().sources();
let mut new_sources = Vec::with_capacity(sources.len());
let mut any_changed = false;
for src in sources.iter() {
let key = UOpKey(src.clone());
if let Some(replacement) = ctx.get(&key) {
new_sources.push(replacement.clone());
any_changed = true;
} else {
new_sources.push(src.clone());
}
}
if any_changed { Some(x.with_sources(new_sources)) } else { None }
},
}
}
/// Records, for the base of the movement chain `src`, a version of `contig` with that chain undone.
///
/// It walks from `src` down to `src.base()`. Each `Permute` is undone with the inverse
/// permutation (`argsort` of its axes), and each `Reshape` by reshaping back to the inner shape.
/// Any other op on the way, or any failing transform, aborts without recording anything.
///
/// This function always returns `None`; its only effect is the insertion into `ctx`.
#[allow(clippy::mutable_key_type)]
fn found_contiguous(ctx: &mut ReplaceContiguousCtx, contig: &Arc<UOp>, src: &Arc<UOp>) -> Option<Arc<UOp>> {
    let base = src.base();
    let mut view = contig.clone();
    let mut cursor = src.clone();
    while !Arc::ptr_eq(&cursor, &base) {
        let (next, undone) = match cursor.op() {
            Op::Permute { src: inner, axes } => (inner.clone(), view.try_permute(argsort(axes)).ok()?),
            Op::Reshape { src: inner, .. } => {
                let inner_shape = inner.shape().ok()??;
                (inner.clone(), view.try_reshape(inner_shape).ok()?)
            }
            _ => return None,
        };
        view = undone;
        cursor = next;
    }
    ctx.insert(UOpKey(base.clone()), view);
    None
}
/// Core rangeify rewrite, driven by the ranges recorded in an [`IndexingContext`].
///
/// Arms in priority order:
/// 1. `ReduceAxis` becomes `Reduce`.
/// 2. `Pad` becomes `WHERE`.
/// 3. A catch-all that inserts bufferize/index around sources.
/// 4. Movement-op removal.
///
/// NOTE(review): the movement arm is only reached when `apply_bufferize_transform` returns
/// `None`. This assumes the matcher falls through to later arms on `None`; confirm.
pub fn apply_rangeify_patterns() -> TypedPatternMatcher<IndexingContext> {
crate::patterns! {
@context IndexingContext;
x @ ReduceAxis { src: _ } => |x, ctx| convert_reduceaxis_with_context(x, ctx),
x @ Pad { src: _, begin_pads: _, end_pads: _ } => |x, ctx| convert_pad_to_where(x, ctx),
x => |x, ctx| apply_bufferize_transform(x, ctx),
x if x.op().is_movement() => |x, ctx| remove_movement_op(x, ctx),
}
}
/// Rebuilds `x` with bufferized sources, when `transform_sources_with_bufferize` proposes any.
///
/// The replacement inherits `x`'s recorded input/output ranges and realize axes, so later
/// lookups in `ctx` still succeed. Returns `None` when the sources need no change.
fn apply_bufferize_transform(x: &Arc<UOp>, ctx: &mut IndexingContext) -> Option<Arc<UOp>> {
    let rewritten = x.with_sources(transform_sources_with_bufferize(x, ctx)?);
    // Carry the indexing metadata of `x` over to its replacement.
    let ranges = ctx.get_ranges(x).map(|(ins, outs)| (ins.clone(), outs.clone()));
    if let Some((ins, outs)) = ranges {
        ctx.set_ranges(&rewritten, ins, outs);
    }
    if let Some(axes) = ctx.get_realize_axes(x).cloned() {
        ctx.mark_realize(&rewritten, axes);
    }
    Some(rewritten)
}
/// Rewrites a `Pad` node into `WHERE(valid, pad_src, 0)`.
///
/// `valid` is the AND of the validity masks (`get_valid`) of every input range recorded for
/// `x`. The new node inherits `x`'s ranges.
///
/// Returns `None` (no rewrite) in these cases:
/// - `x` has no recorded ranges, or it has no input ranges;
/// - the combined mask is the constant `true`;
/// - combining masks, building the zero constant, or building the `WHERE` fails.
fn convert_pad_to_where(x: &Arc<UOp>, ctx: &mut IndexingContext) -> Option<Arc<UOp>> {
    let (input_ranges, output_ranges) = ctx.get_ranges(x)?;
    let input_ranges = input_ranges.clone();
    let output_ranges = output_ranges.clone();

    // AND together every axis' validity mask. If combining fails, bail out instead of silently
    // dropping the masks accumulated so far; dropping them would read padding as real data.
    let mut masks = input_ranges.iter().map(|r| r.get_valid());
    let first = masks.next()?;
    let combined_valid = masks.try_fold(first, |acc, valid| acc.try_and_op(&valid).ok())?;

    // Nothing is ever masked out, so the pad is a no-op at this point.
    if matches!(combined_valid.op(), Op::Const(cv) if cv.0 == ConstValue::Bool(true)) {
        return None;
    }
    let pad_src = x.op().sources().first()?.clone();
    // Decline instead of panicking when the dtype has no scalar form.
    let scalar = x.dtype().scalar()?;
    let zero = UOp::const_(x.dtype(), ConstValue::zero(scalar));
    let ret = UOp::try_where(combined_valid, pad_src, zero).ok()?;
    ctx.set_ranges(&ret, input_ranges, output_ranges);
    Some(ret)
}
/// Eliminates a movement op once its indexing has been captured by rangeify.
///
/// - Over a buffer-like source (`Buffer`, `BufferView`, `MStack`, `MSelect`, `After`) it becomes
///   an explicit `Index` of that source with `x`'s input ranges. This requires recorded ranges.
/// - Over an `Index`, or when `x` has recorded ranges, it is replaced by its source.
/// - Otherwise it is left alone (`None`).
fn remove_movement_op(x: &Arc<UOp>, ctx: &mut IndexingContext) -> Option<Arc<UOp>> {
    let inner = x.op().sources().first()?.clone();
    let is_buffer_like = matches!(
        inner.op(),
        Op::Buffer { .. } | Op::BufferView { .. } | Op::MStack { .. } | Op::MSelect { .. } | Op::After { .. }
    );
    if is_buffer_like {
        let (input_ranges, _) = ctx.get_ranges(x)?;
        return UOp::index().buffer(inner).indices(input_ranges.clone()).call().ok();
    }
    let passthrough = matches!(inner.op(), Op::Index { .. }) || ctx.get_ranges(x).is_some();
    passthrough.then_some(inner)
}
fn convert_reduceaxis_with_context(x: &Arc<UOp>, ctx: &mut IndexingContext) -> Option<Arc<UOp>> {
let Op::ReduceAxis { src, reduce_op, axes } = x.op() else {
return None;
};
let (input_ranges, output_ranges) = ctx.get_ranges(x)?;
let reduce_ranges: SmallVec<[Arc<UOp>; 4]> =
input_ranges.iter().enumerate().filter(|(i, _)| axes.contains(i)).map(|(_, r)| Arc::clone(r)).collect();
let target = if reduce_ranges.is_empty() { Arc::clone(src) } else { src.reduce(reduce_ranges, *reduce_op) };
let target = if let Some(t) = x.tag() { target.with_tag(t.clone()) } else { target };
ctx.set_ranges(&target, input_ranges.clone(), output_ranges.clone());
if let Some(realize_axes) = ctx.get_realize_axes(x).cloned() {
ctx.mark_realize(&target, realize_axes);
}
Some(target)
}
/// Folds trivial buffer round-trips.
///
/// - `Bufferize`, `Index`, and `Copy` of a `Const` collapse to the constant.
/// - An ungated `Index(Bufferize(compute, ranges), indices)` is folded when the indices equal the
///   bufferize ranges and `compute` is not a `BufferView`. It becomes `compute` re-tagged with
///   the concatenation of both nodes' tags.
/// - If `ranges` is non-empty and both shapes are known, the folded node is also shrunk to the
///   bufferize's shape. If shrinking fails, the unshrunk node is returned instead.
#[tracing::instrument]
pub fn buffer_folding() -> TypedPatternMatcher {
crate::patterns! {
Bufferize { compute: c @ Const(_), .. } ~> |c| c.clone(),
Index { buffer: c @ Const(_), .. } ~> |c| c.clone(),
Copy { src: c @ Const(_), .. } ~> |c| c.clone(),
Index { buffer: buf @ Bufferize { compute, ranges, .. }, indices, gate: None }
if ranges_equal(ranges, indices) && !matches!(compute.op(), Op::BufferView { .. })
=> |compute, buf, ranges| {
// Merge tags from both the compute and the bufferize.
let mut merged = SmallVec::<[usize; 2]>::new();
if let Some(t) = compute.tag() { merged.extend(t.iter().copied()); }
if let Some(t) = buf.tag() { merged.extend(t.iter().copied()); }
let tag = if merged.is_empty() { None } else { Some(merged) };
let result = compute.rtag(tag);
// Keep the bufferize's shape by shrinking to it (best effort).
if !ranges.is_empty()
&& let (Ok(Some(buf_shape)), Ok(Some(_))) = (buf.shape(), result.shape()) {
let shrink_ranges: Vec<_> = buf_shape.iter()
.map(|s| (morok_ir::SInt::Const(0), s.clone()))
.collect();
if let Ok(shrunk) = result.try_shrink(&shrink_ranges) {
return Some(shrunk);
}
}
Some(result)
},
}
}
/// Runs `cleanup_dead_axes_bufferize` on every `Bufferize` node.
pub fn dead_axis_removal() -> TypedPatternMatcher {
crate::patterns! {
bufferize @ Bufferize { compute, ranges, opts } => |bufferize, compute, ranges, opts| {
cleanup_dead_axes_bufferize(bufferize, compute, ranges, opts)
},
}
}
fn cleanup_dead_axes_bufferize(
bufferize: &Arc<UOp>,
compute: &Arc<UOp>,
ranges: &SmallVec<[Arc<UOp>; 4]>,
opts: &BufferizeOpts,
) -> Option<Arc<UOp>> {
use morok_ir::SInt;
use morok_ir::shape::Shape;
if matches!(compute.op(), Op::Contiguous { .. } | Op::Copy { .. } | Op::Assign { .. }) {
return None;
}
let original_shape = bufferize.shape().ok().flatten()?;
let compute_ranges = compute.ranges();
let mut new_ranges = Vec::new();
let mut reshape_dims: Shape = SmallVec::new();
let mut had_dead = false;
for (i, range) in ranges.iter().enumerate() {
if let Op::Range { end, .. } = range.op()
&& !matches!(end.op(), Op::Const(_))
{
return None;
}
let is_const = matches!(range.op(), Op::Const(_));
let is_size_one = is_dead_axis(range);
let is_unused = matches!(range.op(), Op::Range { .. }) && !compute_ranges.iter().any(|r| Arc::ptr_eq(r, range));
if is_const || is_size_one || is_unused {
reshape_dims.push(SInt::Const(1)); had_dead = true;
} else {
new_ranges.push(Arc::clone(range));
if let Some(dim) = original_shape.get(i) {
reshape_dims.push(dim.clone());
} else {
return None; }
}
}
if !had_dead {
return None;
}
let reduced = UOp::bufferize(compute.clone(), new_ranges, opts.clone());
let reshaped = reduced.try_reshape(&reshape_dims).ok()?;
reshaped.try_expand(original_shape).ok()
}
/// Inlines bufferizes that are not worth materializing.
///
/// - First arm: applies to a `Bufferize` whose compute is a unary op under a reduce range. It
///   calls `block_reduce_unary_inline`, which currently always returns `None`.
///   NOTE(review): the name suggests this arm should prevent inlining. If the matcher falls
///   through to later arms on `None`, the unary is still inlined by the next arm. Confirm this
///   is intended.
/// - A bufferize of a cheap-to-inline compute (`is_cheap_to_inline`) is replaced by the compute.
/// - A bufferize directly wrapping another bufferize keeps the outer ranges/opts over the
///   inner compute.
pub fn buffer_removal() -> TypedPatternMatcher {
crate::patterns! {
buf @ Bufferize { compute } if unary_in_reduce_context(compute) => |buf| block_reduce_unary_inline(buf),
Bufferize { compute, .. } if is_cheap_to_inline(compute.op()) ~> |compute| compute.clone(),
Bufferize { compute: Bufferize { compute: inner, .. }, ranges, opts }
=> |inner, ranges, opts| Some(UOp::bufferize(Arc::clone(inner), ranges.to_vec(), opts.clone())),
}
}
/// Partial-contiguous buffer removal, configured by [`PcontigConfig`].
///
/// - `Index(Bufferize)` with a single index of the form `WHERE(cond, idx, Invalid)`: inline
///   using the clean `idx`, then wrap the result as `WHERE(cond, value, 0)`.
/// - Ungated `Index(Bufferize)`: inline through `apply_pcontig_removal_inner`.
/// - Gated `Index(Bufferize)`: inline, then wrap as `WHERE(gate, value, 0)`.
/// - `Bufferize` of a `Const` collapses to the constant when `ctx.level > 0`.
/// - Nested `Bufferize(Bufferize(inner))` keeps the outer ranges/opts over `inner`.
pub fn buffer_removal_with_pcontig() -> TypedPatternMatcher<PcontigConfig> {
crate::patterns! {
@context PcontigConfig;
// Single WHERE(cond, idx, Invalid) index: strip the validity, inline, then re-apply it as a mask.
Index {
buffer: buffer @ Bufferize { compute: src, ranges: buf_ranges, .. },
indices: idx_ranges,
gate: None
} if idx_ranges.len() == 1
&& matches!(idx_ranges[0].op(), Op::Ternary(morok_ir::TernaryOp::Where, _, _, f) if matches!(f.op(), Op::Invalid))
=> |buffer, src, buf_ranges, idx_ranges, ctx| {
let Op::Ternary(morok_ir::TernaryOp::Where, cond, clean_idx, _) = idx_ranges[0].op() else {
unreachable!()
};
let clean_indices: SmallVec<[Arc<UOp>; 4]> = smallvec::smallvec![clean_idx.clone()];
let inlined = apply_pcontig_removal_inner(buffer, src, buf_ranges, &clean_indices, ctx)?;
let zero = UOp::const_(inlined.dtype(), ConstValue::zero(inlined.dtype().scalar()?));
UOp::try_where(cond.clone(), inlined, zero).ok()
},
Index {
buffer: buffer @ Bufferize { compute: src, ranges: buf_ranges, .. },
indices: idx_ranges,
gate: None
} => |buffer, src, buf_ranges, idx_ranges, ctx| {
apply_pcontig_removal_inner(buffer, src, buf_ranges, idx_ranges, ctx)
},
// Gated index: inline, then mask with the gate.
Index {
buffer: buffer @ Bufferize { compute: src, ranges: buf_ranges, .. },
indices: idx_ranges,
gate: Some(gate)
} => |buffer, src, buf_ranges, idx_ranges, gate, ctx| {
let inlined = apply_pcontig_removal_inner(buffer, src, buf_ranges, idx_ranges, ctx)?;
let zero = UOp::const_(inlined.dtype(), ConstValue::zero(inlined.dtype().scalar()?));
UOp::try_where(gate.clone(), inlined, zero).ok()
},
Bufferize { compute: compute @ Const(_), .. }
if ctx.level > 0
=> |compute| Some(compute.clone()),
Bufferize { compute: Bufferize { compute: inner, .. }, ranges, opts }
=> |inner, ranges, opts| Some(UOp::bufferize(Arc::clone(inner), ranges.to_vec(), opts.clone())),
}
}
/// Decides whether the `Bufferize` behind an `Index` can be removed, by inlining its compute at
/// the index site, and performs that rewrite when allowed.
///
/// Returns `None` to keep the buffer. The visible decision steps are:
/// 1. Keep the buffer if `config.level == 0`, if `src` is an always-run op, or if the bufferize
///    opts are non-removable.
/// 2. Walk `src` to collect accessed buffers, `Index` nodes, and `Reduce` nodes. The walk does
///    not descend into global `Bufferize` or `MStack` nodes. Keep the buffer if the distinct
///    buffer count exceeds `config.max_buffers_threshold` while `level <= 2`.
/// 3. If no `Reduce` source reaches a `Param` or `Bufferize`, substitute buffer ranges with the
///    index expressions (constant ranges are skipped) and return the inlined compute.
/// 4. Otherwise keep the buffer when `level <= 2`. At higher levels, also keep it when the
///    output/input byte ratio is below `config.out_in_ratio_threshold`. Past that point,
///    partially inline:
///    - ranges used by local-buffer indices, or whose index depends on a reduce range, are
///      re-materialized in a local `Bufferize`;
///    - all remaining ranges are substituted.
#[allow(clippy::mutable_key_type)]
fn apply_pcontig_removal_inner(
buffer: &Arc<UOp>,
src: &Arc<UOp>,
buf_ranges: &SmallVec<[Arc<UOp>; 4]>,
idx_ranges: &SmallVec<[Arc<UOp>; 4]>,
config: &mut PcontigConfig,
) -> Option<Arc<UOp>> {
use morok_ir::{AddrSpace, AxisType, BufferizeOpts, ConstValue};
use std::collections::{HashMap, HashSet};
if config.level == 0 || is_always_run_op(src.op()) {
return None;
}
if let Op::Bufferize { opts, .. } = buffer.op()
&& !opts.removable
{
tracing::debug!(src_id = src.id, src_op = src.op().as_ref(), "buffer_removal: KEPT (non-removable)");
return None;
}
let mut accessed_buffers = Vec::new();
let mut indexes = Vec::new();
let mut reduces = Vec::new();
let mut visited = HashSet::new();
// Recursive walk over `src`. Global bufferizes and MStacks are recorded as buffers but not
// descended into. NOTE(review): the boolean result is never read by any caller.
fn collect(
uop: &Arc<UOp>,
buffers: &mut Vec<Arc<UOp>>,
indexes: &mut Vec<Arc<UOp>>,
reduces: &mut Vec<Arc<UOp>>,
visited: &mut HashSet<UOpKey>,
) -> bool {
let key = UOpKey(Arc::clone(uop));
if !visited.insert(key) {
return true;
}
match uop.op() {
Op::Bufferize { opts, .. } if opts.addrspace == AddrSpace::Global => {
buffers.push(Arc::clone(uop));
return false; }
Op::MStack { .. } => {
buffers.push(Arc::clone(uop));
return false; }
Op::Param { .. } => {
buffers.push(Arc::clone(uop));
}
Op::Index { .. } => indexes.push(Arc::clone(uop)),
Op::Reduce { .. } => reduces.push(Arc::clone(uop)),
_ => {}
}
for child in uop.op().sources() {
collect(&child, buffers, indexes, reduces, visited);
}
true
}
collect(src, &mut accessed_buffers, &mut indexes, &mut reduces, &mut visited);
// NOTE(review): each node is pushed at most once thanks to `visited`, so this dedup appears redundant.
let mut seen = HashSet::new();
accessed_buffers.retain(|b| seen.insert(UOpKey(Arc::clone(b))));
if accessed_buffers.len() > config.max_buffers_threshold && config.level <= 2 {
tracing::debug!(
src_id = src.id,
src_op = src.op().as_ref(),
buf_count = accessed_buffers.len(),
threshold = config.max_buffers_threshold,
buf_ops = ?accessed_buffers.iter().map(|b| (b.id, b.op().as_ref())).collect::<Vec<_>>(),
"buffer_removal: KEPT (buffer count exceeds threshold)"
);
return None;
}
// Does any reduce read from a buffer (Param or Bufferize) somewhere below its source?
let buffer_in_reduce = if reduces.is_empty() {
false
} else {
let reduce_sources: Vec<Arc<UOp>> = reduces
.iter()
.filter_map(|r| if let Op::Reduce { src, .. } = r.op() { Some(Arc::clone(src)) } else { None })
.collect();
if reduce_sources.is_empty() {
false
} else {
let sink = UOp::sink(reduce_sources);
sink.any_in_subtree(|n| matches!(n.op(), Op::Param { .. } | Op::Bufferize { .. }))
}
};
// Cheap case: inline fully by substituting every non-constant buffer range with its index.
if !buffer_in_reduce {
tracing::debug!(
src_id = src.id,
src_op = src.op().as_ref(),
buf_count = accessed_buffers.len(),
"buffer_removal: REMOVED (no buffer_in_reduce)"
);
let subs_map: HashMap<UOpKey, Arc<UOp>> = buf_ranges
.iter()
.zip(idx_ranges.iter())
.filter(|(k, _)| !matches!(k.op(), Op::Const(_)))
.map(|(k, v)| (UOpKey(Arc::clone(k)), Arc::clone(v)))
.collect();
return Some(src.substitute_gated(&subs_map));
}
if config.level <= 2 {
tracing::debug!(
src_id = src.id,
src_op = src.op().as_ref(),
reduce_count = reduces.len(),
buf_count = accessed_buffers.len(),
"buffer_removal: KEPT (buffer_in_reduce at level<=2)"
);
return None;
}
// Output size in bytes. Any non-constant or non-positive range aborts (keeps the buffer).
let output_size = match buffer.op() {
Op::Bufferize { ranges, .. } => {
let mut product = 1usize;
for range in ranges {
if let Op::Range { end, .. } = range.op()
&& let Op::Const(cv) = end.op()
&& let ConstValue::Int(n) = cv.0
&& n > 0
{
product = product.checked_mul(n as usize)?;
} else {
return None;
}
}
let element_size = buffer.dtype().base().bytes();
product.checked_mul(element_size)?
}
Op::Buffer { size, .. } => *size,
_ => return None,
};
// Total input size in bytes.
// NOTE(review): unlike the output size above, non-constant ranges are silently skipped here,
// so symbolic inputs are under-counted, which makes removal more likely. Confirm intended.
// NOTE(review): an MStack is counted as 1, which looks like a placeholder.
let input_size: usize = accessed_buffers
.iter()
.filter_map(|buf| match buf.op() {
Op::Bufferize { ranges, .. } => {
let mut product = 1usize;
for range in ranges {
if let Op::Range { end, .. } = range.op()
&& let Op::Const(cv) = end.op()
&& let ConstValue::Int(n) = cv.0
&& n > 0
{
product = product.checked_mul(n as usize)?;
}
}
let elem_size = buf.dtype().base().bytes();
product.checked_mul(elem_size)
}
Op::Param { size, .. } => Some(*size),
Op::MStack { .. } => Some(1),
_ => None,
})
.sum();
// The +1 terms avoid dividing by zero.
let ratio = (output_size + 1) as f64 / (input_size + 1) as f64;
if ratio < config.out_in_ratio_threshold {
return None;
}
// Ranges referenced by indices into local bufferizes must stay materialized.
let local_indexes: Vec<_> = indexes
.iter()
.filter(|idx| {
matches!(idx.op(), Op::Index { buffer, .. }
if matches!(buffer.op(), Op::Bufferize { opts, .. }
if opts.addrspace == AddrSpace::Local))
})
.collect();
let mut exclude_ranges = HashSet::new();
for idx in &local_indexes {
if let Op::Index { indices, .. } = idx.op() {
for range in indices {
for r in range.in_scope_ranges() {
exclude_ranges.insert(r.clone());
}
}
}
}
// Split the non-constant ranges into "materialize locally" and "substitute".
let mut materialize = Vec::new();
let mut substitute = Vec::new();
for (buf_rng, idx_rng) in buf_ranges.iter().zip(idx_ranges.iter()) {
if matches!(buf_rng.op(), Op::Const(_)) {
continue;
}
let buf_key = UOpKey(Arc::clone(buf_rng));
let should_materialize = exclude_ranges.contains(&buf_key)
|| idx_rng.in_scope_ranges().iter().any(|r| {
if let Op::Range { axis_type, .. } = r.0.op() { matches!(axis_type, AxisType::Reduce) } else { false }
});
if should_materialize {
materialize.push((Arc::clone(buf_rng), Arc::clone(idx_rng)));
} else {
substitute.push((Arc::clone(buf_rng), Arc::clone(idx_rng)));
}
}
if substitute.is_empty() {
return None;
}
let subs_map: HashMap<UOpKey, Arc<UOp>> = substitute.into_iter().map(|(k, v)| (UOpKey(k), v)).collect();
let substituted = src.substitute_gated(&subs_map);
if materialize.is_empty() {
return Some(substituted);
}
// Partial contiguous: bufferize the remaining axes in local memory and index into it.
let (mat_buf_rngs, mat_idx_rngs): (Vec<_>, Vec<_>) = materialize.into_iter().unzip();
let opts = BufferizeOpts::local();
let bufferized = UOp::bufferize(substituted, mat_buf_rngs, opts);
UOp::index().buffer(bufferized).indices(mat_idx_rngs).call().ok()
}
/// Applies `split_reduceop` to every `ReduceAxis`, configured by [`SplitReduceOpConfig`].
pub fn split_reduceop_patterns() -> TypedPatternMatcher<SplitReduceOpConfig> {
crate::patterns! {
@context SplitReduceOpConfig;
reduce @ ReduceAxis { src: _ } => |reduce, ctx| split_reduceop(reduce, ctx),
}
}
/// Removes reduce ranges that the reduced value does not depend on (cached matcher).
///
/// For a `Reduce` with op Add, Mul, Max, or Min, ranges that are not in
/// `src.in_scope_ranges()` ("unparented") are dropped and compensated for:
/// - Add: the result is multiplied by each dropped range's size.
/// - Mul: the result is raised to the power of each dropped range's size.
/// - Max/Min: no compensation.
///
/// The parented ranges stay as a smaller `Reduce`. That `Reduce` is omitted entirely when none
/// remain and the dtypes already match. Returns `None` if nothing is unparented, or if a range
/// size or its cast cannot be computed.
///
/// # Panics
/// Panics if any reduce range is not a `Range` op.
fn pm_reduce_unparented() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
reduce @ Reduce { src, ranges, reduce_op: Add | Mul | Max | Min } => |reduce, src, ranges, reduce_op| {
assert!(
ranges.iter().all(|r| matches!(r.op(), Op::Range { .. })),
"reduce_unparented: all reduce ranges must be RANGE ops, got: {:?}",
ranges.iter().map(|r| r.op().as_ref().to_string()).collect::<Vec<_>>()
);
#[allow(clippy::mutable_key_type)]
let src_ranges = src.in_scope_ranges();
let (parented, unparented) = partition_reduce_ranges(ranges, src_ranges);
if unparented.is_empty() {
return None;
}
// Keep a Reduce over the parented ranges, or when it is needed for the dtype change.
let mut result = if !parented.is_empty() || reduce.dtype() != src.dtype() {
src.reduce(parented, *reduce_op)
} else {
Arc::clone(src)
};
match reduce_op {
ReduceOp::Add => {
for range in &unparented {
let size = get_range_size(range)?;
let size_casted = cast_to_dtype(&size, &result.dtype())?;
result = result.try_mul(&size_casted).ok()?;
}
}
ReduceOp::Mul => {
for range in &unparented {
let size = get_range_size(range)?;
let size_casted = cast_to_dtype(&size, &result.dtype())?;
result = result.try_pow(&size_casted).ok()?;
}
}
ReduceOp::Max | ReduceOp::Min => {}
}
Some(result)
},
}
}
/// Returns `true` if any of `ranges` is among the ranges in scope at `uop`.
#[allow(clippy::mutable_key_type)]
fn references_any_reduce_range(uop: &Arc<UOp>, ranges: &SmallVec<[Arc<UOp>; 4]>) -> bool {
    let scope = uop.in_scope_ranges();
    for range in ranges {
        if scope.contains(&UOpKey(Arc::clone(range))) {
            return true;
        }
    }
    false
}
fn split_mul_factors(uop: &Arc<UOp>) -> SmallVec<[Arc<UOp>; 4]> {
match uop.op() {
Op::Binary(BinaryOp::Mul, a, b) => {
let mut factors = split_mul_factors(a);
factors.extend(split_mul_factors(b));
factors
}
_ => smallvec::smallvec![uop.clone()],
}
}
/// Hoists range-independent factors of a product out of a reduction.
///
/// `reduce(a * b * c, ranges)` becomes `reduce(inner, ranges) * hoisted...`. A factor is hoisted
/// when it does not reference any of `ranges`. For `Max`, the factor's `vmin` must additionally
/// be non-negative, because scaling by a negative value does not commute with max.
///
/// Returns `None` when `src` has fewer than two factors or nothing can be hoisted.
fn reduce_mul_chain(src: &Arc<UOp>, ranges: &SmallVec<[Arc<UOp>; 4]>, reduce_op: ReduceOp) -> Option<Arc<UOp>> {
    let factors = split_mul_factors(src);
    if factors.len() < 2 {
        return None;
    }
    let hoistable = |factor: &Arc<UOp>| -> bool {
        if references_any_reduce_range(factor, ranges) {
            return false;
        }
        reduce_op != ReduceOp::Max
            || match factor.vmin() {
                ConstValue::Int(v) => *v >= 0,
                ConstValue::Float(v) => *v >= 0.0,
                ConstValue::UInt(_) | ConstValue::Bool(_) => true,
            }
    };
    let (outside, inside): (Vec<Arc<UOp>>, Vec<Arc<UOp>>) = factors.into_iter().partition(|f| hoistable(f));
    if outside.is_empty() {
        return None;
    }
    // With no range-dependent factor left, the reduction is over the constant 1.
    let inner = inside.into_iter().reduce(|a, b| a.mul(&b)).unwrap_or_else(|| src.const_like(1i64));
    let hoisted = outside.iter().fold(inner.reduce(ranges.clone(), reduce_op), |acc, factor| acc.mul(factor));
    Some(hoisted)
}
/// Reduce simplifications, built once and cached in a `LazyLock`.
///
/// The matcher combines:
/// - the `pm_reduce_unparented` rules;
/// - `reduce_collapse` for `Add` reductions;
/// - `reduce_mul_chain` for `Add`/`Max` reductions over a `Mul`.
pub fn pm_reduce_simplify() -> &'static TypedPatternMatcher {
static CACHED: std::sync::LazyLock<TypedPatternMatcher> = std::sync::LazyLock::new(|| {
pm_reduce_unparented()
+ crate::patterns! {
Reduce { src, ranges, reduce_op } if *reduce_op == ReduceOp::Add
=> |src, ranges| super::transforms::reduce_collapse(src, ranges),
// Factor range-independent multiplicands out of the reduction.
Reduce { src, ranges, reduce_op }
if matches!(reduce_op, ReduceOp::Add | ReduceOp::Max)
&& matches!(src.op(), Op::Binary(BinaryOp::Mul, _, _))
=> |src, ranges, reduce_op| reduce_mul_chain(src, ranges, *reduce_op),
}
});
&CACHED
}
/// Pushes movement ops past indexing and ordering nodes.
///
/// - `Index(mop, indices, gate)`: the indices are rewritten through the movement op and applied
///   to its source.
/// - `After(mop, deps)`: delegated to `push_movement_through_after`.
/// - `End(mop, ranges)`: ends the movement op's source instead.
/// - Ungated `Index(Index(_, [i]), [i])`, where both single indices are the same node (compared by
///   `id`), collapses to the inner `Index`.
pub fn movement_op_patterns() -> TypedPatternMatcher {
crate::patterns! {
Index { buffer: mop, indices, gate } if mop.op().is_movement() => |mop, indices, gate| {
transform_movement_through_index(mop, indices, gate)
},
After { passthrough: mop, deps } if mop.op().is_movement()
=> |mop, deps| {
super::transforms::push_movement_through_after(mop, deps)
},
End { computation: mop, ranges } if mop.op().is_movement()
=> |mop, ranges| {
let src = &mop.op().sources()[0];
Some(src.end(ranges.clone()))
},
Index {
buffer: inner @ Index { indices: inner_indices, gate: None },
indices: outer_indices,
gate: None
} if inner_indices.len() == 1 && outer_indices.len() == 1
&& inner_indices[0].id == outer_indices[0].id
~> |inner| inner.clone(),
}
}
/// Flattens nested indexing (cached matcher).
///
/// `Index(Index(base, inner), outer)`, where the inner index is pointer-typed and the outer is
/// not, becomes a single `Index` over `base`. The new node has the concatenated indices, the
/// combined gates, and the outer dtype.
pub fn pm_syntactic_sugar() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
outer @ Index { buffer: inner @ Index { buffer: base_buffer, indices: inner_indices, gate: inner_gate }, indices: outer_indices, gate: outer_gate }
if matches!(inner.dtype(), DType::Ptr { .. }) && !matches!(outer.dtype(), DType::Ptr { .. })
=> |outer, inner, base_buffer, inner_indices, outer_indices, inner_gate, outer_gate| {
concat_index_indices(base_buffer, inner_indices, outer_indices, inner_gate, outer_gate, outer.dtype())
},
}
}
/// Builds a single `Index` over `base_buffer` from a nested inner/outer index pair.
///
/// The indices are the inner indices followed by the outer ones. The gate is the AND of both
/// gates when both exist, whichever one exists otherwise, or none. Returns `None` if the index
/// builder rejects the result.
fn concat_index_indices(
    base_buffer: &Arc<UOp>,
    inner_indices: &SmallVec<[Arc<UOp>; 4]>,
    outer_indices: &SmallVec<[Arc<UOp>; 4]>,
    inner_gate: &Option<Arc<UOp>>,
    outer_gate: &Option<Arc<UOp>>,
    result_dtype: DType,
) -> Option<Arc<UOp>> {
    let indices: SmallVec<[Arc<UOp>; 4]> = inner_indices.iter().chain(outer_indices.iter()).cloned().collect();
    let gate = match (inner_gate.as_ref(), outer_gate.as_ref()) {
        (Some(a), Some(b)) => Some(a.and_(b)),
        // At most one side is present here.
        (one, other) => one.or(other).cloned(),
    };
    let builder = UOp::index().buffer(Arc::clone(base_buffer)).indices(indices).dtype(result_dtype);
    match gate {
        Some(g) => builder.gate(g).call().ok(),
        None => builder.call().ok(),
    }
}
/// Rewrites `Index(mop, indices, gate)` into an `Index` on `mop`'s source.
///
/// The indices are mapped through the movement op with `apply_movement_op`; any gate is kept.
/// Returns `None` when the source shape is unknown or the index cannot be built.
pub(crate) fn transform_movement_through_index(
    mop: &Arc<UOp>,
    indices: &SmallVec<[Arc<UOp>; 4]>,
    gate: &Option<Arc<UOp>>,
) -> Option<Arc<UOp>> {
    use super::indexing::{SimplifyCache, apply_movement_op};
    let source = &mop.op().sources()[0];
    let source_shape = source.shape().ok()??;
    let mut cache = SimplifyCache::default();
    let mapped = apply_movement_op(mop.op(), source_shape, indices.as_slice(), &mut cache);
    let builder = UOp::index().buffer(source.clone()).indices(mapped);
    let built = match gate {
        Some(g) => builder.gate(g.clone()).call(),
        None => builder.call(),
    };
    built.ok()
}
/// Builds a zero constant of `dtype`.
///
/// Vector dtypes get a `Vectorize` of per-lane scalar zeros.
fn dtype_zero(dtype: DType) -> Arc<UOp> {
    let scalar = dtype.base();
    let zero = ConstValue::zero(scalar);
    if !dtype.is_vector() {
        return UOp::const_(dtype, zero);
    }
    let lane_count = dtype.count();
    UOp::vectorize((0..lane_count).map(|_| UOp::const_(DType::Scalar(scalar), zero)).collect())
}
/// Codegen-time cleanups that use a [`LocalAddBufferContext`].
///
/// - A non-void `Noop` becomes a zero of its dtype.
/// - `Contiguous` is stripped to its source. A non-empty `opts` list is appended to `ctx.opts`.
pub fn rangeify_codegen_patterns() -> TypedPatternMatcher<LocalAddBufferContext> {
crate::patterns! {
@context LocalAddBufferContext;
noop @ Noop() if noop.dtype().base() != morok_dtype::ScalarDType::Void => |noop, _ctx| {
Some(dtype_zero(noop.dtype()))
},
Contiguous { src, opts } => |src, opts, ctx| {
if !opts.is_empty() {
ctx.opts.extend(opts.iter().cloned());
}
Some(src.clone())
},
}
}
/// Context-free variant of `rangeify_codegen_patterns`.
///
/// A non-void `Noop` becomes a zero of its dtype. `Contiguous` is stripped to its source and
/// its opts are discarded.
pub fn rangeify_codegen_simple() -> TypedPatternMatcher {
crate::patterns! {
noop @ Noop() if noop.dtype().base() != morok_dtype::ScalarDType::Void => |noop| {
Some(dtype_zero(noop.dtype()))
},
Contiguous { src, .. } => |src| {
Some(src.clone())
},
}
}
/// Same rewrites as `rangeify_codegen_simple`, typed for a `KernelContext`.
///
/// The context is not read or written by any arm.
pub fn rangeify_codegen_with_kernel_ctx() -> TypedPatternMatcher<super::kernel::KernelContext> {
crate::patterns! {
@context super::kernel::KernelContext;
noop @ Noop() if noop.dtype().base() != morok_dtype::ScalarDType::Void => |noop, _ctx| {
Some(dtype_zero(noop.dtype()))
},
Contiguous { src, .. } => |src, _ctx| {
Some(src.clone())
},
}
}
/// Returns the pointee dtype of a `Ptr` dtype. Any other dtype is returned unchanged.
fn extract_base_dtype(dtype: DType) -> DType {
    if let DType::Ptr { base, .. } = &dtype {
        return (**base).clone();
    }
    dtype
}
/// Resolves the concrete buffer behind an `After` passthrough.
///
/// Returns the first buffer of a non-empty `MStack`, or the buffer of an `MSelect`. Otherwise
/// (including an empty `MStack`) the passthrough itself is returned.
fn extract_buffer_from_after(passthrough: &Arc<UOp>) -> Arc<UOp> {
    let picked = match passthrough.op() {
        Op::MStack { buffers } => buffers.first(),
        Op::MSelect { buffer, .. } => Some(buffer),
        _ => None,
    };
    picked.unwrap_or(passthrough).clone()
}
/// Finds the output `Param` of a kernel AST.
///
/// Nodes are scanned in `toposort` order. The result is the first store target that is a
/// device-less `Param`, after unwrapping one level of `Index` around it.
fn find_kernel_output(ast: &Arc<UOp>) -> Option<Arc<UOp>> {
    ast.toposort().into_iter().find_map(|node| {
        let stored = node.store_buffer()?;
        let target = if let Op::Index { buffer: inner, .. } = stored.op() { inner.clone() } else { stored.clone() };
        matches!(target.op(), Op::Param { device: None, .. }).then_some(target)
    })
}
/// Lowers buffers and bindings to kernel parameters, recording mappings in a [`KernelContext`].
///
/// - `Buffer` becomes a global-pointer `Param` at `ctx.next_global()`. The buffer is recorded
///   via `map_buffer`.
/// - `Bind(var, value)` records `var` (with `value` when it is a constant accepted by `try_int`)
///   and is replaced by `var`.
/// - `After` is replaced by its underlying buffer. Unless that buffer is a local pointer, the
///   buffer is also mapped to the `After`.
/// - `Const` and `DefineVar` nodes that have sources are rebuilt without them.
/// - A `Range` with a constant-zero end becomes index constant `0`. Other unrenumbered ranges
///   are renumbered via `ctx.next_range()`.
/// - `Kernel` is replaced by the output `Param` found in its AST.
pub fn to_param_patterns() -> TypedPatternMatcher<KernelContext> {
crate::patterns! {
@context KernelContext;
buf @ Buffer { size, unique: _ } => |buf, size, ctx| {
let ptr_dtype = extract_base_dtype(buf.dtype()).ptr(Some(*size), AddrSpace::Global);
let replacement = UOp::param(ctx.next_global(), *size, ptr_dtype, None);
ctx.map_buffer(buf.clone(), replacement.clone());
Some(replacement)
},
Bind { var, value } => |var, value, ctx| {
let bound_val = match value.op() {
Op::Const(cv) => cv.0.try_int(),
_ => None,
};
ctx.add_var(var.clone(), bound_val);
Some(var.clone())
},
after @ After { passthrough } => |after, passthrough, ctx| {
let buf = extract_buffer_from_after(passthrough);
if matches!(buf.dtype(), DType::Ptr { addrspace: AddrSpace::Local, .. }) {
return Some(buf);
}
ctx.map_buffer(buf.clone(), after.clone());
Some(buf)
},
// Strip sources from constants/variables (rebuilt as source-less nodes).
c @ Const(_) | c @ DefineVar { name: _ } => |c, _ctx| {
let sources = c.op().sources();
if sources.is_empty() { return None; }
Some(match c.op() {
Op::Const(val) => UOp::const_(c.dtype(), val.0),
Op::DefineVar { name, min_val, max_val } => UOp::var(name.clone(), c.dtype(), *min_val, *max_val),
_ => return None,
})
},
Range { end } if matches!(end.op(), Op::Const(v) if v.0.is_zero()) => |_r, _ctx| {
Some(UOp::index_const(0))
},
Range { end, axis_id, axis_type } if matches!(axis_id, AxisId::Unrenumbered(_)) => |_r, end, axis_type, ctx| {
Some(UOp::range_axis(end.clone(), AxisId::Renumbered(ctx.next_range()), *axis_type))
},
Kernel { ast } => |_k, ast, _ctx| find_kernel_output(ast),
}
}
/// Lowers buffers and bindings to parameters, recording mappings in a [`LocalAddBufferContext`].
///
/// - `Buffer`, and a `Param` that carries a device, become a new global-pointer `Param` at
///   `ctx.next_param_slot()`. The original node is recorded (mapped to itself) the first time it
///   is seen.
/// - `Bind(var, value)` records `var` (with `value` when it is a constant accepted by `try_int`)
///   and is replaced by `var`.
/// - `After` over a non-local passthrough maps its underlying buffer to the `After` and is
///   replaced by that buffer. The underlying buffer is the first `MStack` member or the
///   `MSelect` buffer. A duplicate mapping trips `debug_assert!` and is otherwise skipped with a
///   warning.
/// - `Const` and `DefineVar` nodes that have sources are rebuilt without them.
/// - A `Range` with a constant-zero end becomes index constant `0`.
/// - An unrenumbered `Outer` range is renumbered and bound to a new `range_{id}` `DefineVar`
///   spanning its vmin..vmax (each falls back to 0 when it is not an integer). Other
///   unrenumbered ranges are only renumbered.
pub fn local_to_param_patterns() -> TypedPatternMatcher<LocalAddBufferContext> {
crate::patterns! {
@context LocalAddBufferContext;
buf @ Buffer { size, unique: _ } => |buf, size, ctx| {
let ptr_dtype = extract_base_dtype(buf.dtype()).ptr(Some(*size), AddrSpace::Global);
let replacement = UOp::param(ctx.next_param_slot(), *size, ptr_dtype, None);
if !ctx.has_buffer(buf) {
ctx.map_buffer(buf.clone(), buf.clone());
}
Some(replacement)
},
buf @ Param { slot: _, size } if matches!(buf.op(), Op::Param { device: Some(_), .. }) => |buf, size, ctx| {
let ptr_dtype = extract_base_dtype(buf.dtype()).ptr(Some(*size), AddrSpace::Global);
let replacement = UOp::param(ctx.next_param_slot(), *size, ptr_dtype, None);
if !ctx.has_buffer(buf) {
ctx.map_buffer(buf.clone(), buf.clone());
}
Some(replacement)
},
Bind { var, value } => |var, value, ctx| {
let bound_val = match value.op() {
Op::Const(cv) => cv.0.try_int(),
_ => None,
};
ctx.add_var(var.clone(), bound_val);
Some(var.clone())
},
after @ After { passthrough } => |after, passthrough, ctx| {
if matches!(passthrough.dtype(), DType::Ptr { addrspace: AddrSpace::Local, .. }) {
return None;
}
// Same unwrapping as `extract_buffer_from_after`, applied to `after.buf_uop()`.
let buf = after.buf_uop();
let buf = match buf.op() {
Op::MStack { buffers } if !buffers.is_empty() => buffers[0].clone(),
Op::MSelect { buffer, .. } => buffer.clone(),
_ => buf,
};
if ctx.has_buffer(&buf) {
debug_assert!(false, "handle_after: duplicate buffer mapping for buf id={}", buf.id);
tracing::warn!(buf_id = buf.id, "handle_after: duplicate buffer mapping, skipping");
return None;
}
ctx.map_buffer(buf.clone(), after.clone());
Some(buf)
},
c @ Const(_) | c @ DefineVar { name: _ } => |c, _ctx| {
let sources = c.op().sources();
if sources.is_empty() { return None; }
Some(match c.op() {
Op::Const(val) => UOp::const_(c.dtype(), val.0),
Op::DefineVar { name, min_val, max_val } => UOp::var(name.clone(), c.dtype(), *min_val, *max_val),
_ => return None,
})
},
Range { end } if matches!(end.op(), Op::Const(v) if v.0.is_zero()) => |_r, _ctx| {
Some(UOp::index_const(0))
},
// Outer ranges are exposed as a bound variable named after the new range id.
r @ Range { end: _, axis_id, axis_type }
if matches!(axis_id, AxisId::Unrenumbered(_)) && *axis_type == AxisType::Outer
=> |r, _axis_id, _axis_type, ctx| {
let vmin = r.vmin().try_int().unwrap_or(0);
let vmax = r.vmax().try_int().unwrap_or(0);
let range_id = ctx.next_range();
let var_name = format!("range_{range_id}");
let define_var = UOp::var(var_name, morok_dtype::DType::Index, vmin, vmax);
let Op::Range { end, axis_type, .. } = r.op() else { return None };
let range_cleared = UOp::range_axis(end.clone(), AxisId::Renumbered(range_id), *axis_type);
let bound = define_var.bind(range_cleared);
ctx.add_var(define_var, None);
Some(bound)
},
Range { end, axis_id, axis_type } if matches!(axis_id, AxisId::Unrenumbered(_)) => |_r, end, axis_type, ctx| {
Some(UOp::range_axis(end.clone(), AxisId::Renumbered(ctx.next_range()), *axis_type))
},
}
}
/// Calls `split_store` on every `Store` and `End` node, passing the `Vec<Arc<UOp>>` context.
///
/// NOTE(review): the context presumably accumulates the split-off kernels; confirm in
/// `split_store`.
pub fn split_kernels_pattern() -> TypedPatternMatcher<Vec<Arc<UOp>>> {
use super::kernel::split_store;
crate::patterns! {
@context Vec<Arc<UOp>>;
x @ Store { index: _, value: _, .. } => |x, ctx| split_store(ctx, x),
x @ End { computation: _ } => |x, ctx| split_store(ctx, x),
}
}
/// Returns the first device found in a depth-first, pre-order walk from `root`.
///
/// A node supplies a device when it is:
/// - a `Device` op;
/// - a `Buffer` whose device source is a `Device` op;
/// - a `Bufferize` whose opts carry a device.
///
/// Shared sub-graphs are visited only once.
#[allow(clippy::mutable_key_type)]
pub fn extract_device_from_graph(root: &Arc<UOp>) -> Option<DeviceSpec> {
    // The device directly attached to `uop`, if any (children are not examined).
    fn local_device(uop: &Arc<UOp>) -> Option<DeviceSpec> {
        match uop.op() {
            Op::Device(spec) => Some(spec.clone()),
            Op::Buffer { device, .. } => match device.op() {
                Op::Device(spec) => Some(spec.clone()),
                _ => None,
            },
            Op::Bufferize { opts, .. } => opts.device.clone(),
            _ => None,
        }
    }

    // Recursive pre-order search that stops at the first hit.
    fn search(uop: &Arc<UOp>, seen: &mut HashSet<UOpKey>) -> Option<DeviceSpec> {
        if !seen.insert(UOpKey(Arc::clone(uop))) {
            return None;
        }
        local_device(uop).or_else(|| uop.op().sources().iter().find_map(|child| search(child, seen)))
    }

    let mut seen = HashSet::new();
    search(root, &mut seen)
}
/// For every binary and ternary op, runs `check_buffer_limit` on the op's sources with
/// `max_buffers`.
///
/// NOTE(review): the purpose of `#[allow(unused_variables)]` is not visible here. It is
/// presumably for bindings generated by the `for op in ...` macro expansion; confirm.
#[allow(unused_variables)] pub fn buffer_limit_patterns(max_buffers: usize) -> TypedPatternMatcher<()> {
crate::patterns! {
for op in binary [*] {
tree@op(a, b) => |tree, a, b| {
check_buffer_limit(tree, &[a.clone(), b.clone()], max_buffers)
},
}
for op in ternary [*] {
tree@op(a, b, c) => |tree, a, b, c| {
check_buffer_limit(tree, &[a.clone(), b.clone(), c.clone()], max_buffers)
},
}
}
}
fn check_buffer_limit(tree: &Arc<UOp>, sources: &[Arc<UOp>], max_buffers: usize) -> Option<Arc<UOp>> {
let all_buffers = collect_accessed_buffers(sources);
if all_buffers.len() > max_buffers.saturating_sub(1) {
let mut any_changed = false;
let new_sources: Vec<_> = sources
.iter()
.map(|src| {
if is_elementwise(src) {
let new = force_bufferize(src);
if !Arc::ptr_eq(&new, src) {
any_changed = true;
}
new
} else {
src.clone()
}
})
.collect();
if any_changed {
return Some(tree.with_sources(new_sources));
}
}
None
}
/// Collects the distinct buffer-like nodes reachable from `sources`.
///
/// A node counts as a buffer when it is one of:
/// - a `Bufferize` in global address space;
/// - a `Buffer`;
/// - an `MStack`;
/// - an `MSelect`.
///
/// The walk stops at a global `Bufferize` and does not visit its sources. It does
/// continue into the sources of the other buffer kinds. Results are returned in
/// first-visit (pre-order) order, each node at most once.
fn collect_accessed_buffers(sources: &[Arc<UOp>]) -> Vec<Arc<UOp>> {
    #[allow(clippy::mutable_key_type)]
    fn collect_recursive(uop: &Arc<UOp>, buffers: &mut Vec<Arc<UOp>>, visited: &mut HashSet<UOpKey>) {
        if !visited.insert(UOpKey(Arc::clone(uop))) {
            return;
        }
        match uop.op() {
            Op::Bufferize { opts, .. } if opts.addrspace == AddrSpace::Global => {
                buffers.push(Arc::clone(uop));
                // Do not descend below a global bufferize.
                return;
            }
            Op::Buffer { .. } | Op::MStack { .. } | Op::MSelect { .. } => buffers.push(Arc::clone(uop)),
            _ => {}
        }
        for child in uop.op().sources() {
            collect_recursive(&child, buffers, visited);
        }
    }

    let mut all_buffers = Vec::new();
    // `visited` guarantees each node is pushed at most once, so no separate
    // deduplication pass over `all_buffers` is needed.
    #[allow(clippy::mutable_key_type)]
    let mut visited = HashSet::new();
    for src in sources {
        collect_recursive(src, &mut all_buffers, &mut visited);
    }
    all_buffers
}
/// Materializes `src` into a removable global buffer and reads it back through an `Index`.
///
/// `src`'s own ranges are used both for the bufferize and for the index. If `src`
/// has no ranges, or the index cannot be built, `src` is returned unchanged.
fn force_bufferize(src: &Arc<UOp>) -> Arc<UOp> {
    let ranges = src.ranges().clone();
    if ranges.is_empty() {
        return Arc::clone(src);
    }
    let buffer = UOp::bufferize(
        Arc::clone(src),
        ranges.clone(),
        BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true },
    );
    match UOp::index().buffer(buffer).indices(ranges).call() {
        Ok(indexed) => indexed,
        Err(_) => Arc::clone(src),
    }
}
/// Pattern matcher that makes memory reads explicit as `Load`s.
///
/// - An `Index` whose dtype is neither a pointer nor an image is rebuilt as an
///   `Index` over the same buffer, indices and gate, typed with the buffer's
///   dtype. That pointer index is wrapped in a `Load` producing the original
///   `Index` dtype.
/// - A `Store` whose index is a `Load` stores through that load's index instead,
///   keeping the store's value and ranges (presumably undoing the first rule on
///   store targets).
pub fn pm_add_loads() -> &'static TypedPatternMatcher<()> {
crate::cached_patterns! {
// Value-typed INDEX -> LOAD(pointer-typed INDEX).
idx @ Index { buffer, indices } if !matches!(idx.dtype(), DType::Ptr { .. } | DType::Image { .. }) => |idx, buffer, indices| {
let result_dtype = idx.dtype().clone();
// Carry over the gate of the matched Index, if it has one.
let gate = match idx.op() {
Op::Index { gate, .. } => gate.clone(),
_ => None,
};
let ptr_index = UOp::new(
Op::Index { buffer: buffer.clone(), indices: indices.clone(), gate },
buffer.dtype().clone(), );
Some(UOp::load().buffer(buffer.clone()).index(ptr_index).dtype(result_dtype).call())
},
// STORE(LOAD(idx), v) -> STORE(idx, v): store through the underlying index.
Store { index: Load { index: real_index, .. }, value, ranges } =>
|real_index, value, ranges| {
Some(real_index.store_with_ranges(value.clone(), ranges.clone()))
},
}
}
/// Returns `true` for bool dtypes with more than one lane.
fn is_vectorized_bool(dtype: &DType) -> bool {
    let multi_lane = dtype.vcount() > 1;
    multi_lane && dtype.base() == ScalarDType::Bool
}
fn devectorize_binary(op: &BinaryOp, result: &Arc<UOp>, a: &Arc<UOp>, b: &Arc<UOp>) -> Option<Arc<UOp>> {
let out_vcount = result.dtype().vcount();
if out_vcount <= 1 {
return None;
}
let a_vcount = a.dtype().vcount();
let b_vcount = b.dtype().vcount();
let scalar_dtype = result.dtype().scalar_dtype();
let scalar_ops: SmallVec<[Arc<UOp>; 4]> = (0..out_vcount)
.map(|i| {
let a_elem = if a_vcount > 1 { a.gep(vec![i]) } else { a.clone() };
let b_elem = if b_vcount > 1 { b.gep(vec![i]) } else { b.clone() };
UOp::new(Op::Binary(*op, a_elem, b_elem), scalar_dtype.clone())
})
.collect();
Some(UOp::vectorize(scalar_ops))
}
/// Splits a vectorized unary op into per-lane scalar ops joined by a `Vectorize`.
///
/// Lane `i` applies `op` to `src.gep([i])`, with `result`'s scalar dtype.
/// Returns `None` when `result` has a single lane.
fn devectorize_unary(op: &UnaryOp, result: &Arc<UOp>, src: &Arc<UOp>) -> Option<Arc<UOp>> {
    let lanes = result.dtype().vcount();
    if lanes > 1 {
        let lane_dtype = result.dtype().scalar_dtype();
        let per_lane: SmallVec<[Arc<UOp>; 4]> = (0..lanes)
            .map(|lane| UOp::new(Op::Unary(*op, src.gep(vec![lane])), lane_dtype.clone()))
            .collect();
        Some(UOp::vectorize(per_lane))
    } else {
        None
    }
}
fn devectorize_generic(uop: &Arc<UOp>) -> Option<Arc<UOp>> {
let vcount = uop.dtype().vcount();
if vcount <= 1 {
return None;
}
let scalar_dtype = uop.dtype().scalar_dtype();
let sources = uop.op().sources();
let elements: SmallVec<[Arc<UOp>; 4]> = (0..vcount)
.map(|i| {
let new_sources: Vec<Arc<UOp>> =
sources.iter().map(|s| if s.dtype().vcount() > 1 { s.gep(vec![i]) } else { s.clone() }).collect();
match uop.op() {
Op::Cast { .. } => new_sources[0].cast(scalar_dtype.clone()),
Op::BitCast { .. } => new_sources[0].bitcast(scalar_dtype.clone()),
_ => uop.replace().dtype(scalar_dtype.clone()).src(new_sources).call(),
}
})
.collect();
Some(UOp::vectorize(elements))
}
/// Pattern matcher that splits multi-lane bool computations into scalar lanes.
///
/// Rewritten nodes:
/// - binary and unary ops whose result is a vectorized bool;
/// - `Where` with a vectorized condition;
/// - `Index` and `BitCast` with a vectorized bool dtype;
/// - `Cast` to or from a vectorized bool.
///
/// Each becomes a `Vectorize` of scalar ops built by `devectorize_binary`,
/// `devectorize_unary`, `devectorize_where` or `devectorize_generic`.
/// NOTE(review): presumably this avoids `<N x i1>` vectors in codegen — confirm.
pub fn pm_bool_devectorize() -> &'static TypedPatternMatcher<()> {
crate::cached_patterns! {
for op in binary [*] {
result @ op(a, b) if is_vectorized_bool(&result.dtype()) => |result, a, b| {
devectorize_binary(&op, result, a, b)
},
},
for op in unary [*] {
result @ op(src) if is_vectorized_bool(&result.dtype()) => |result, src| {
devectorize_unary(&op, result, src)
},
},
// Keyed on the condition's lane count, not on a bool result dtype.
Where(cond, t, f) if cond.dtype().vcount() > 1 => |cond, t, f| {
devectorize_where(cond, t, f)
},
idx @ Index { buffer: _, .. } if is_vectorized_bool(&idx.dtype()) => devectorize_generic(idx),
// Casts are split both when producing and when consuming vector bools.
c @ Cast { src: _, .. } if is_vectorized_bool(&c.dtype()) => devectorize_generic(c),
c @ Cast { src, .. } if is_vectorized_bool(&src.dtype()) => devectorize_generic(c),
bc @ BitCast { src: _, .. } if is_vectorized_bool(&bc.dtype()) => devectorize_generic(bc),
}
}
/// Splits a `Where` with a vectorized condition into per-lane scalar `Where`s.
///
/// Lane `i` uses `cond.gep([i])`. `t` and `f` contribute their own lane `i` when
/// they have more than one lane, and are reused as-is otherwise.
///
/// Returns `None` in two cases:
/// - the condition has a single lane;
/// - any per-lane `Where` cannot be constructed. The node is then left
///   unrewritten instead of panicking, like the other `Option`-returning
///   rewrite helpers.
fn devectorize_where(cond: &Arc<UOp>, t: &Arc<UOp>, f: &Arc<UOp>) -> Option<Arc<UOp>> {
    let vcount = cond.dtype().vcount();
    if vcount <= 1 {
        return None;
    }
    let t_vcount = t.dtype().vcount();
    let f_vcount = f.dtype().vcount();
    // Collecting into `Option<_>` stops at the first lane that fails to build.
    let scalar_wheres = (0..vcount)
        .map(|i| {
            let cond_elem = cond.gep(vec![i]);
            let t_elem = if t_vcount > 1 { t.gep(vec![i]) } else { t.clone() };
            let f_elem = if f_vcount > 1 { f.gep(vec![i]) } else { f.clone() };
            UOp::try_where(cond_elem, t_elem, f_elem).ok()
        })
        .collect::<Option<SmallVec<[Arc<UOp>; 4]>>>()?;
    Some(UOp::vectorize(scalar_wheres))
}
/// Combines two partial reduction values with the binary op matching `reduce_op`.
///
/// `Add`, `Mul` and `Max` become the binary op of the same name, typed `dtype`.
/// `Min` is expressed as `where(a < b, a, b)`, using a bool condition with
/// `dtype`'s lane count. Debug builds assert that both operands share a dtype.
fn apply_reduce_binary(reduce_op: ReduceOp, a: Arc<UOp>, b: Arc<UOp>, dtype: &DType) -> Arc<UOp> {
    debug_assert!(
        a.dtype() == b.dtype(),
        "apply_reduce_binary: dtype mismatch between operands: a={:?}, b={:?}",
        a.dtype(),
        b.dtype()
    );
    let binary = match reduce_op {
        ReduceOp::Add => BinaryOp::Add,
        ReduceOp::Mul => BinaryOp::Mul,
        ReduceOp::Max => BinaryOp::Max,
        ReduceOp::Min => {
            let a_lt_b = UOp::new(Op::Binary(BinaryOp::Lt, a.clone(), b.clone()), DType::Bool.vec(dtype.vcount()));
            return UOp::try_where(a_lt_b, a, b).unwrap();
        }
    };
    UOp::new(Op::Binary(binary, a, b), dtype.clone())
}
/// Splits `src` into pieces that can be combined lane-wise to reach `out_dtype`'s lane count.
///
/// Let `k = src_lanes / out_lanes`. Normally returns `k` geps, where gep `i`
/// selects lanes `i, i + k, i + 2k, …`.
///
/// If `src_lanes` is not a multiple of `out_lanes`, or `k` is zero, it instead
/// returns a single scalar that folds every lane of `src` together with
/// `reduce_op`.
fn horizontal_reduce(src: &Arc<UOp>, out_dtype: &DType, reduce_op: ReduceOp) -> Vec<Arc<UOp>> {
    let src_count = src.dtype().vcount();
    let out_count = out_dtype.vcount();
    let stride = src_count / out_count;
    if src_count.is_multiple_of(out_count) && stride != 0 {
        return (0..stride)
            .map(|start| {
                let lanes: Vec<usize> = (start..src_count).step_by(stride).collect();
                src.gep(lanes)
            })
            .collect();
    }
    // Uneven split: fold every lane into one scalar value.
    let scalar_dtype = src.dtype().scalar_dtype();
    let lanes: Vec<Arc<UOp>> = (0..src_count).map(|lane| src.gep(vec![lane])).collect();
    let folded = lanes
        .into_iter()
        .reduce(|acc, elem| apply_reduce_binary(reduce_op, acc, elem, &scalar_dtype))
        .expect("src_count >= 2 guaranteed by guard");
    vec![folded]
}
/// Horizontally reduces a `Reduce` whose source has more lanes than its output.
///
/// Steps:
/// 1. Split the source with `horizontal_reduce`.
/// 2. Chain the pieces with `reduce_op`.
/// 3. Re-wrap the result in a `Reduce` over the same ranges, or return it
///    directly when there are no ranges.
///
/// Returns `None` for non-`Reduce` ops, or when the source is not wider than
/// the output.
fn transform_vectorized_reduce(reduce: &Arc<UOp>) -> Option<Arc<UOp>> {
    let Op::Reduce { src, ranges, reduce_op } = reduce.op() else {
        return None;
    };
    let op = *reduce_op;
    let (src_vcount, out_vcount) = (src.dtype().vcount(), reduce.dtype().vcount());
    if src_vcount <= out_vcount {
        return None;
    }
    let out_dtype = reduce.dtype();
    trace!(
        src_vcount,
        out_vcount,
        reduce_op = ?reduce_op,
        out_dtype = ?out_dtype,
        "horizontal reducing vectorized REDUCE source"
    );
    let mut pieces = horizontal_reduce(src, &out_dtype, op).into_iter();
    let first = pieces.next().expect("horizontal_reduce always returns non-empty list");
    let chained = pieces.fold(first, |acc, elem| apply_reduce_binary(op, acc, elem, &out_dtype));
    let result = if ranges.is_empty() {
        chained
    } else {
        UOp::new(Op::Reduce { src: chained, ranges: ranges.clone(), reduce_op: op }, out_dtype)
    };
    Some(result)
}
/// Decides whether a `Reduce` should be devectorized by `pm_reduce_devectorize`.
///
/// Only reductions with a multi-lane output qualify. Among those, it holds when
/// any of the following is true:
/// - the source is a `Contract`;
/// - the reduction is over bools with equal source/output lane counts;
/// - the source has more lanes than the output.
fn needs_reduce_devectorize(reduce: &Arc<UOp>) -> bool {
    let Op::Reduce { src, .. } = reduce.op() else {
        return false;
    };
    let out_vcount = reduce.dtype().vcount();
    if out_vcount <= 1 {
        return false;
    }
    let src_vcount = src.dtype().vcount();
    matches!(src.op(), Op::Contract { .. })
        || (reduce.dtype().base() == ScalarDType::Bool && src_vcount == out_vcount)
        || src_vcount > out_vcount
}
/// True when the reduce output has more than one lane and its source is a `Contract`.
#[inline]
fn is_k_vectorized(reduce: &Arc<UOp>, src: &Arc<UOp>) -> bool {
    let source_is_contract = matches!(src.op(), Op::Contract { .. });
    source_is_contract && reduce.dtype().vcount() > 1
}
/// True for a multi-lane bool reduce whose non-`Contract` source has the same lane count.
#[inline]
fn is_bool_reduce(reduce: &Arc<UOp>, src: &Arc<UOp>) -> bool {
    if matches!(src.op(), Op::Contract { .. }) {
        return false;
    }
    let out_dtype = reduce.dtype();
    let lanes = out_dtype.vcount();
    lanes > 1 && out_dtype.base() == ScalarDType::Bool && src.dtype().vcount() == lanes
}
/// Pattern matcher that devectorizes `Reduce` nodes selected by `needs_reduce_devectorize`.
///
/// Dispatch order:
/// 1. Multi-lane output with a `Contract` source: `devectorize_to_scalar_accumulators`.
/// 2. Multi-lane bool reduce with matching lane counts: `devectorize_bool_reduce`.
/// 3. Anything else: `transform_vectorized_reduce` (horizontal reduction).
pub fn pm_reduce_devectorize() -> &'static TypedPatternMatcher<()> {
crate::cached_patterns! {
reduce @ Reduce { src } if needs_reduce_devectorize(reduce) => |reduce, src| {
if is_k_vectorized(reduce, src) {
devectorize_to_scalar_accumulators(reduce)
} else if is_bool_reduce(reduce, src) {
devectorize_bool_reduce(reduce)
} else {
transform_vectorized_reduce(reduce)
}
},
}
}
fn devectorize_bool_reduce(reduce: &Arc<UOp>) -> Option<Arc<UOp>> {
let Op::Reduce { src, ranges, reduce_op } = reduce.op() else {
return None;
};
let vcount = reduce.dtype().vcount();
if vcount <= 1 {
return None;
}
let scalar_dtype = reduce.dtype().scalar_dtype();
trace!(
vcount,
reduce_op = ?reduce_op,
src_dtype = ?src.dtype(),
"devectorizing bool REDUCE to avoid <N x i1> accumulators"
);
let scalar_reduces: SmallVec<[Arc<UOp>; 4]> = (0..vcount)
.map(|i| {
let src_elem = src.gep(vec![i]);
UOp::new(Op::Reduce { src: src_elem, ranges: ranges.clone(), reduce_op: *reduce_op }, scalar_dtype.clone())
})
.collect();
Some(UOp::vectorize(scalar_reduces))
}
/// Replaces a multi-lane `Reduce` with per-lane scalar `Reduce`s combined by `tree_reduce`.
///
/// If the source is a `Contract`, its inner source is used. Lane `i` (for each
/// of the reduce's output lanes) reduces `inner.gep([i])` over the same ranges.
/// All lanes are then folded pairwise with `reduce_op` into one value of the
/// scalar dtype. Returns `None` for non-`Reduce` ops and single-lane outputs.
/// NOTE(review): the result is scalar although `reduce`'s dtype has several
/// lanes; presumably the lanes are partial accumulators of one value — confirm.
fn devectorize_to_scalar_accumulators(reduce: &Arc<UOp>) -> Option<Arc<UOp>> {
    let Op::Reduce { src, ranges, reduce_op } = reduce.op() else {
        return None;
    };
    let vec_count = reduce.dtype().vcount();
    if vec_count <= 1 {
        return None;
    }
    // Look through a `Contract` wrapper to the vector it packs.
    let vec_src = match src.op() {
        Op::Contract { src: inner, .. } => inner.clone(),
        _ => src.clone(),
    };
    let scalar_dtype = reduce.dtype().scalar_dtype();
    trace!(
        vec_count,
        reduce_op = ?reduce_op,
        src_dtype = ?vec_src.dtype(),
        "devectorizing K-vectorized REDUCE to scalar accumulators"
    );
    let op = *reduce_op;
    let mut accumulators: Vec<Arc<UOp>> = Vec::with_capacity(vec_count);
    for lane in 0..vec_count {
        let lane_src = vec_src.gep(vec![lane]);
        accumulators.push(UOp::new(Op::Reduce { src: lane_src, ranges: ranges.clone(), reduce_op: op }, scalar_dtype.clone()));
    }
    Some(tree_reduce(&accumulators, op, &scalar_dtype))
}
/// Combines `elements` pairwise, level by level, into a single value.
///
/// Each level pairs neighbours `(0,1), (2,3), …` with `apply_reduce_binary`. An
/// odd trailing element is carried to the next level unchanged. The result is a
/// balanced tree, about `log2(len)` levels deep.
///
/// # Panics
/// Panics if `elements` is empty.
fn tree_reduce(elements: &[Arc<UOp>], reduce_op: ReduceOp, dtype: &DType) -> Arc<UOp> {
    let mut level: Vec<Arc<UOp>> = elements.to_vec();
    while level.len() > 1 {
        level = level
            .chunks(2)
            .map(|pair| match pair {
                [lhs, rhs] => apply_reduce_binary(reduce_op, lhs.clone(), rhs.clone(), dtype),
                [single] => single.clone(),
                _ => unreachable!("chunks(2) yields one or two elements"),
            })
            .collect();
    }
    level.remove(0)
}
/// Pattern matcher fusing `a * b + c` into a `MulAcc(a, b, c)`.
///
/// It applies only when `a`, `b` and `c` share the same float dtype. The
/// `Add[..]` form presumably also matches `c + a * b` — confirm against the
/// pattern DSL. If the `MulAcc` cannot be built, nothing is rewritten.
pub fn pm_fma_decomposition() -> &'static TypedPatternMatcher<()> {
crate::cached_patterns! {
Add[Mul(a, b), c] if a.dtype().is_float() && a.dtype() == b.dtype() && a.dtype() == c.dtype() => |a, b, c| {
UOp::try_mulacc(a.clone(), b.clone(), c.clone()).ok()
},
}
}
/// True when no node in `u`'s subtree is a `Range`.
fn no_range(u: &Arc<UOp>) -> bool {
    let has_range = u.any_in_subtree(|node| matches!(node.op(), Op::Range { .. }));
    !has_range
}
/// True when no node in `u`'s subtree is an `Index`.
///
/// An `Index` is used here as the marker of a memory access.
fn no_load(u: &Arc<UOp>) -> bool {
    let has_index = u.any_in_subtree(|node| matches!(node.op(), Op::Index { .. }));
    !has_index
}
/// True when `u` is a `Const` whose value is zero.
fn is_const_zero(u: &Arc<UOp>) -> bool {
    match u.op() {
        Op::Const(cv) => cv.0.is_zero(),
        _ => false,
    }
}
/// Builds `min(a, b)` as `-max(-a, -b)`.
///
/// Returns `None` if the `max` node cannot be constructed.
fn uop_min(a: &Arc<UOp>, b: &Arc<UOp>) -> Option<Arc<UOp>> {
    a.neg().try_max(&b.neg()).ok().map(|largest_neg| largest_neg.neg())
}
/// Removes `range` from `expr` by substituting a bounds-checked `idx` for it.
///
/// Steps:
/// 1. Cast `idx` to the range's dtype and build `in_bounds = (idx >= 0) & (idx < end)`.
/// 2. Replace `range` in `expr` with `idx.valid(in_bounds)`.
/// 3. Return `where(in_bounds, expr', 0)`, with a zero of `expr`'s dtype.
///
/// Returns `None` if any comparison, `&` or `where` cannot be built.
fn gated_collapse_core(idx: &Arc<UOp>, range: &Arc<UOp>, end: &Arc<UOp>, expr: &Arc<UOp>) -> Option<Arc<UOp>> {
    let idx_casted = idx.cast(range.dtype());
    let zero = UOp::index_const(0);
    let at_least_zero = idx_casted.try_cmpge(&zero).ok()?;
    let below_end = idx_casted.try_cmplt(end).ok()?;
    let in_bounds = at_least_zero.try_and_op(&below_end).ok()?;
    let guarded_idx = idx_casted.valid(in_bounds.clone());
    #[allow(clippy::mutable_key_type)]
    let mut subs: std::collections::HashMap<UOpKey, Arc<UOp>> = std::collections::HashMap::new();
    subs.insert(UOpKey(range.clone()), guarded_idx);
    let substituted = expr.substitute(&subs);
    let zero_like = UOp::const_(expr.dtype(), ConstValue::zero(expr.dtype().base()));
    UOp::try_where(in_bounds, substituted, zero_like).ok()
}
/// Closed-form evaluation of a single-range `Add` reduction over a masked value.
///
/// Notation: `r` is the reduce range, `end` its bound, and `v` a value with no
/// ranges. `r >= lo` may be spelled `Ge(r, lo)` or `Not(r < lo)`. All bounds
/// (`cut`, `lo`, `hi`) must be free of ranges.
///
/// Recognized shapes of `src`:
/// - `where(r < cut, 0, v)` becomes `min(max(end - cut, 0), end) * v`.
/// - `where(r < cut, v, 0)` becomes `min(max(cut, 0), end) * v`.
/// - `where(r >= lo, v, 0)` becomes `min(max(end - lo, 0), end) * v`.
/// - `where(r >= lo, 0, v)` becomes `min(max(lo, 0), end) * v`.
/// - `where(idx != r, 0, e)` and `where(idx == r, e, 0)` (either `==` operand
///   order; `r` may be under a `Cast`): `idx` must be range-free. The result is
///   `e` with `r := idx`, gated on `0 <= idx < end` (see `gated_collapse_core`).
/// - `where(lo <= r & r < hi, v, 0)` becomes
///   `min(max(min(hi, end) - max(lo, 0), 0), end) * v`.
///
/// Iteration counts are cast to the value's dtype before multiplying.
///
/// Returns `None` in these cases:
/// - other reduce ops, or more than one range;
/// - unrecognized shapes;
/// - any node cannot be built.
fn try_reduce_collapse(
    _reduce: &Arc<UOp>,
    src: &Arc<UOp>,
    ranges: &SmallVec<[Arc<UOp>; 4]>,
    reduce_op: ReduceOp,
) -> Option<Arc<UOp>> {
    if reduce_op != ReduceOp::Add {
        return None;
    }
    if ranges.len() != 1 {
        return None;
    }
    let range = &ranges[0];
    let Op::Range { end, .. } = range.op() else {
        return None;
    };
    let Op::Ternary(morok_ir::TernaryOp::Where, cond, true_val, false_val) = src.op() else {
        return None;
    };
    // where(r < cut, 0, v): v is summed for the iterations with r >= cut.
    // `cut` must not depend on any range (as required of `lower`/`upper` below);
    // otherwise `end - cut` is not a closed form and could leak `r` out of its reduce.
    if let Op::Binary(BinaryOp::Lt, lt_lhs, cut) = cond.op()
        && Arc::ptr_eq(lt_lhs, range)
        && is_const_zero(true_val)
        && no_range(false_val)
        && no_range(cut)
    {
        let zero = UOp::index_const(0);
        let diff = end.try_sub(cut).ok()?;
        let non_negative = diff.try_max(&zero).ok()?;
        let count = uop_min(&non_negative, end)?;
        let count_casted = count.cast(false_val.dtype());
        return count_casted.try_mul(false_val).ok();
    }
    // where(r < cut, v, 0): v is summed for the iterations with r < cut.
    if let Op::Binary(BinaryOp::Lt, lt_lhs, cut) = cond.op()
        && Arc::ptr_eq(lt_lhs, range)
        && is_const_zero(false_val)
        && no_range(true_val)
        && no_range(cut)
    {
        let zero = UOp::index_const(0);
        let clamped = cut.try_max(&zero).ok()?;
        let count = uop_min(&clamped, end)?;
        let count_casted = count.cast(true_val.dtype());
        return count_casted.try_mul(true_val).ok();
    }
    // where(r >= lower, v, 0): v is summed for the iterations with r >= lower.
    if let Some(lower) = extract_ge_lower_bound(cond, range)
        && is_const_zero(false_val)
        && no_range(true_val)
        && no_range(&lower)
    {
        let zero = UOp::index_const(0);
        let diff = end.try_sub(&lower).ok()?;
        let non_negative = diff.try_max(&zero).ok()?;
        let count = uop_min(&non_negative, end)?;
        let count_casted = count.cast(true_val.dtype());
        return count_casted.try_mul(true_val).ok();
    }
    // where(r >= lower, 0, v): v is summed for the iterations with r < lower.
    if let Some(lower) = extract_ge_lower_bound(cond, range)
        && is_const_zero(true_val)
        && no_range(false_val)
        && no_range(&lower)
    {
        let zero = UOp::index_const(0);
        let clamped = lower.try_max(&zero).ok()?;
        let count = uop_min(&clamped, end)?;
        let count_casted = count.cast(false_val.dtype());
        return count_casted.try_mul(false_val).ok();
    }
    // Gated select: exactly one iteration (r == idx) contributes, if idx is in bounds.
    {
        let (idx, cmp_range, expr) = match cond.op() {
            Op::Binary(BinaryOp::Ne, idx, ne_range) if is_const_zero(true_val) && no_range(idx) => {
                Some((idx, ne_range, false_val))
            }
            Op::Binary(BinaryOp::Eq, lhs, rhs) if is_const_zero(false_val) => {
                if no_range(lhs) {
                    Some((lhs, rhs, true_val))
                } else if no_range(rhs) {
                    Some((rhs, lhs, true_val))
                } else {
                    None
                }
            }
            _ => None,
        }?;
        let actual_range = if let Op::Cast { src, .. } = cmp_range.op() { src } else { cmp_range };
        if Arc::ptr_eq(actual_range, range) {
            return gated_collapse_core(idx, range, end, expr);
        }
    }
    // where(lower <= r & r < upper, v, 0): v is summed for the iterations inside [lower, upper).
    if let Op::Binary(BinaryOp::And, lhs_cond, rhs_cond) = cond.op()
        && is_const_zero(false_val)
        && no_range(true_val)
    {
        let lower_bound = extract_ge_lower_bound(lhs_cond, range).or_else(|| extract_ge_lower_bound(rhs_cond, range));
        let upper_bound = extract_lt_upper_bound(rhs_cond, range).or_else(|| extract_lt_upper_bound(lhs_cond, range));
        if let (Some(lower), Some(upper)) = (lower_bound, upper_bound)
            && no_range(&lower)
            && no_range(&upper)
        {
            let zero = UOp::index_const(0);
            let clamped_upper = uop_min(&upper, end)?;
            let clamped_lower = lower.try_max(&zero).ok()?;
            let diff = clamped_upper.try_sub(&clamped_lower).ok()?;
            let non_negative = diff.try_max(&zero).ok()?;
            let count = uop_min(&non_negative, end)?;
            let count_casted = count.cast(true_val.dtype());
            return count_casted.try_mul(true_val).ok();
        }
    }
    None
}
/// Returns `lower` when `cond` says `range >= lower`, otherwise `None`.
///
/// Both spellings are accepted: `Not(range < lower)` and `Ge(range, lower)`.
/// `range` must be the very same node (pointer equality).
fn extract_ge_lower_bound(cond: &Arc<UOp>, range: &Arc<UOp>) -> Option<Arc<UOp>> {
    let (lhs, lower) = match cond.op() {
        Op::Unary(UnaryOp::Not, negated) => match negated.op() {
            Op::Binary(BinaryOp::Lt, lhs, lower) => (lhs, lower),
            _ => return None,
        },
        Op::Binary(BinaryOp::Ge, lhs, lower) => (lhs, lower),
        _ => return None,
    };
    if Arc::ptr_eq(lhs, range) { Some(lower.clone()) } else { None }
}
/// Returns `upper` when `cond` is `range < upper` (same `range` node), otherwise `None`.
fn extract_lt_upper_bound(cond: &Arc<UOp>, range: &Arc<UOp>) -> Option<Arc<UOp>> {
    match cond.op() {
        Op::Binary(BinaryOp::Lt, lhs, upper) if Arc::ptr_eq(lhs, range) => Some(upper.clone()),
        _ => None,
    }
}
/// Factors a `DefineVar` out of an `Add` reduction of a gated value.
///
/// `src` must be `where(var & other, t, 0)` (or `other & var`), where `var` is a
/// `DefineVar`. It is rewritten as `sum(where(other, t, 0)) * cast(var)`, where
/// the sum is over `ranges` and `var` is cast to `t`'s dtype.
/// Returns `None` if the shape does not match or a node cannot be built.
fn try_define_var_factor(src: &Arc<UOp>, ranges: &SmallVec<[Arc<UOp>; 4]>) -> Option<Arc<UOp>> {
    let Op::Ternary(morok_ir::TernaryOp::Where, cond, true_val, false_val) = src.op() else {
        return None;
    };
    let Op::Binary(BinaryOp::And, and_lhs, and_rhs) = cond.op() else {
        return None;
    };
    if !is_const_zero(false_val) {
        return None;
    }
    let is_var = |u: &Arc<UOp>| matches!(u.op(), Op::DefineVar { .. });
    let (define_var, other) = if is_var(and_lhs) {
        (and_lhs, and_rhs)
    } else if is_var(and_rhs) {
        (and_rhs, and_lhs)
    } else {
        return None;
    };
    let inner_where = UOp::try_where(other.clone(), true_val.clone(), false_val.clone()).ok()?;
    let inner_reduce = inner_where.reduce(ranges.clone(), ReduceOp::Add);
    inner_reduce.try_mul(&define_var.cast(true_val.dtype())).ok()
}
/// Moves range-free arithmetic off the left side of `lhs < c`.
///
/// Rewrites:
/// - `(x + y) < c` becomes `x < c - y`, when `y` has no ranges.
/// - `(x * y) < c` becomes `x < (c + y - 1) / y`, when `y` has no ranges and
///   its integer `vmin` is positive.
///
/// A `Cast` wrapping `lhs` is looked through, with `c` cast to the inner dtype.
/// Returns `None` if `c` contains a range, no rule applies, or a node cannot be built.
fn try_lift_arithmetic_from_lt(cond: &Arc<UOp>) -> Option<Arc<UOp>> {
    let Op::Binary(BinaryOp::Lt, lhs, rhs) = cond.op() else {
        return None;
    };
    if !no_range(rhs) {
        return None;
    }
    let (inner_lhs, bound) = match lhs.op() {
        Op::Cast { src, .. } => (src.as_ref(), rhs.cast(src.dtype())),
        _ => (lhs.as_ref(), rhs.clone()),
    };
    match inner_lhs.op() {
        Op::Binary(BinaryOp::Add, x, y) if no_range(y) => {
            let shifted = bound.try_sub(y).ok()?;
            x.try_cmplt(&shifted).ok()
        }
        Op::Binary(BinaryOp::Mul, x, y) if no_range(y) => {
            let ConstValue::Int(ymin) = y.vmin() else {
                return None;
            };
            if *ymin <= 0 {
                return None;
            }
            // Ceiling division of the bound by a positive y.
            let one = UOp::index_const(1);
            let rounded_up = bound.try_add(y).ok()?.try_sub(&one).ok()?;
            let divided = rounded_up.try_div(y).ok()?;
            x.try_cmplt(&divided).ok()
        }
        _ => None,
    }
}
/// Moves range-free arithmetic across an `==` so that the range-dependent term stands alone.
///
/// One side of the `==` must be range-free; it becomes the target `c`. The other
/// side, after looking through a `Cast` (with `c` cast to match), is rewritten:
/// - `x + y == c` becomes `x == c - y` (when `y` has no ranges), or `y == c - x`
///   (when `x` has no ranges);
/// - `x - y == c` becomes `x == c + y` (when `y` has no ranges), or `y == x - c`
///   (when `x` has no ranges).
///
/// Returns `None` otherwise.
fn try_lift_arithmetic_from_eq(cond: &Arc<UOp>) -> Option<Arc<UOp>> {
    let Op::Binary(BinaryOp::Eq, raw_lhs, raw_rhs) = cond.op() else {
        return None;
    };
    let (varying, invariant) = if no_range(raw_rhs) {
        (raw_lhs, raw_rhs)
    } else if no_range(raw_lhs) {
        (raw_rhs, raw_lhs)
    } else {
        return None;
    };
    let (base, target) = match varying.op() {
        Op::Cast { src, .. } => (src.as_ref(), invariant.cast(src.dtype())),
        _ => (varying.as_ref(), invariant.clone()),
    };
    let rewritten = match base.op() {
        Op::Binary(BinaryOp::Add, x, y) if no_range(y) => x.try_cmpeq(&target.try_sub(y).ok()?),
        Op::Binary(BinaryOp::Add, x, y) if no_range(x) => y.try_cmpeq(&target.try_sub(x).ok()?),
        Op::Binary(BinaryOp::Sub, x, y) if no_range(y) => x.try_cmpeq(&target.try_add(y).ok()?),
        Op::Binary(BinaryOp::Sub, x, y) if no_range(x) => y.try_cmpeq(&x.try_sub(&target).ok()?),
        _ => return None,
    };
    rewritten.ok()
}
/// Rewrites `(x + y) >= c` into `x >= c - y` (or `y >= c - x`), moving the range-free addend.
///
/// Requires `c` to be range-free. The range-free addend is `y` when possible,
/// otherwise `x`. Returns `None` otherwise.
fn try_lift_arithmetic_from_ge(cond: &Arc<UOp>) -> Option<Arc<UOp>> {
    let Op::Binary(BinaryOp::Ge, lhs, rhs) = cond.op() else {
        return None;
    };
    if !no_range(rhs) {
        return None;
    }
    let Op::Binary(BinaryOp::Add, x, y) = lhs.op() else {
        return None;
    };
    let (kept, moved) = if no_range(y) {
        (x, y)
    } else if no_range(x) {
        (y, x)
    } else {
        return None;
    };
    kept.try_cmpge(&rhs.try_sub(moved).ok()?).ok()
}
/// Pattern matcher for collapsing reductions over loads.
///
/// - A single-range `Add` `Reduce` is handed to `transforms::reduce_load_collapse`.
/// - `(x + y) < c` becomes `x < c - y` under these conditions:
///   - `x` has dtype `Index` and contains an `Index` node;
///   - neither `y` nor `c` contains an `Index` node.
pub fn pm_load_collapse() -> &'static TypedPatternMatcher<()> {
crate::cached_patterns! {
_reduce @ Reduce { src, ranges, reduce_op }
if ranges.len() == 1 && *reduce_op == ReduceOp::Add
=> |src, ranges| {
super::transforms::reduce_load_collapse(src, ranges)
},
// Isolate the load-dependent term on the left of the comparison.
Lt(Add(x, y), c)
if x.dtype() == DType::Index && !no_load(x) && no_load(y) && no_load(c)
=> |x, y, c| {
let new_c = c.try_sub(y).ok()?;
x.try_cmplt(&new_c).ok()
},
}
}
/// Shared matcher for reduce collapsing.
///
/// It combines `reduce_collapse_inner_patterns()` with `crate::symbolic::symbolic()`
/// and is built once, on first use.
pub fn build_reduce_collapse_matcher() -> &'static TypedPatternMatcher<()> {
    static MATCHER: std::sync::LazyLock<TypedPatternMatcher<()>> = std::sync::LazyLock::new(|| {
        let inner = reduce_collapse_inner_patterns();
        inner + crate::symbolic::symbolic()
    });
    &MATCHER
}
/// Shared matcher for load collapsing.
///
/// It combines `build_reduce_collapse_matcher()` with `ne_lifting_patterns()` and
/// is built once, on first use.
pub fn build_reduce_load_collapse_matcher() -> &'static TypedPatternMatcher<()> {
    static MATCHER: std::sync::LazyLock<TypedPatternMatcher<()>> = std::sync::LazyLock::new(|| {
        let base = build_reduce_collapse_matcher();
        base + ne_lifting_patterns()
    });
    &MATCHER
}
/// Patterns that move a range-free addend across `!=`.
///
/// - `(x + y) != c` becomes `x != c - y`.
/// - `cast(x + y) != c` becomes `x != cast(c) - y`, where `c` is cast to the
///   dtype of the un-cast sum.
///
/// Both require `y` and `c` to contain no ranges.
fn ne_lifting_patterns() -> TypedPatternMatcher<()> {
crate::patterns! {
Ne(Add(x, y), c) if no_range(y) && no_range(c) => |x, y, c| {
let new_c = c.try_sub(y).ok()?;
x.try_cmpne(&new_c).ok()
},
Ne(Cast { src: inner, .. }, c) if no_range(c) => |inner, c| {
let Op::Binary(BinaryOp::Add, x, y) = inner.op() else { return None };
if !no_range(y) { return None; }
// Compare in the sum's dtype, not the cast-to dtype.
let casted_c = c.cast(inner.dtype());
let new_c = casted_c.try_sub(y).ok()?;
x.try_cmpne(&new_c).ok()
},
}
}
/// Inner rewrite rules used while collapsing reductions.
///
/// It combines `pm_reduce_unparented()` (lifted to this context) with rules that:
/// - lift range-free arithmetic out of `<`, `>=` and `==` comparisons
///   (`try_lift_arithmetic_from_lt/ge/eq`);
/// - distribute an `Add` reduction over an `Add`: `sum(x + y)` becomes
///   `sum(x) + sum(y)`;
/// - collapse masked `Add` reductions in closed form (`try_reduce_collapse`),
///   falling back to factoring out a `DefineVar` gate (`try_define_var_factor`);
/// - turn `x * cast(gate)` with a bool `gate` into `where(gate, x, 0)`.
fn reduce_collapse_inner_patterns() -> TypedPatternMatcher<()> {
pm_reduce_unparented().with_context()
+ crate::patterns! {
cond @ Lt(_, rhs) if no_range(rhs) => |cond| try_lift_arithmetic_from_lt(cond),
cond @ Ge(_, rhs) if no_range(rhs) => |cond| try_lift_arithmetic_from_ge(cond),
// sum(x + y) -> sum(x) + sum(y)
Reduce { src, ranges, reduce_op } if *reduce_op == ReduceOp::Add => |src, ranges| {
let Op::Binary(BinaryOp::Add, x, y) = src.op() else { return None };
let x_reduced = x.reduce(ranges.clone(), ReduceOp::Add);
let y_reduced = y.reduce(ranges.clone(), ReduceOp::Add);
x_reduced.try_add(&y_reduced).ok()
},
reduce @ Reduce { src, ranges, reduce_op }
if !ranges.is_empty() && *reduce_op == ReduceOp::Add
=> |reduce, src, ranges| {
try_reduce_collapse(reduce, src, ranges, ReduceOp::Add)
.or_else(|| try_define_var_factor(src, ranges))
},
// x * cast(bool gate) -> where(gate, x, 0)
Mul[x, Cast { src: gate, .. }] if gate.dtype() == DType::Bool => |x, gate| {
let zero = UOp::const_(x.dtype(), ConstValue::zero(x.dtype().base()));
UOp::try_where(gate.clone(), x.clone(), zero).ok()
},
cond @ Eq[_, c] if no_range(c) => |cond| try_lift_arithmetic_from_eq(cond),
}
}
/// Rewrites integer `x % c` into `x & (c - 1)` when `c` is a positive power of two.
///
/// The identity only holds for non-negative `x`. `pm_div_to_shr` biases negative
/// dividends before shifting, i.e. it treats integer division as truncating.
/// Under the matching truncating remainder, `x % c` is `<= 0` for negative `x`,
/// while `x & (c - 1)` is always `>= 0`.
///
/// The rewrite is therefore applied only when `x` is unsigned or its lower
/// bound (`vmin`) is `>= 0`.
/// NOTE(review): confirm that `Mod` has truncating semantics in this IR.
pub fn pm_mod_to_and() -> &'static TypedPatternMatcher<()> {
    use morok_ir::types::ConstValue;
    use morok_ir::uop::cached_property::CachedProperty;
    use morok_ir::uop::properties::VminVmaxProperty;
    crate::cached_patterns! {
        Mod(x, _c @const(c_val)) => |x, c_val| {
            if !x.dtype().is_int() { return None; }
            let n = match c_val {
                ConstValue::Int(v) if v > 0 && (v as u64).is_power_of_two() => v,
                ConstValue::UInt(v) if v > 0 && v.is_power_of_two() => v as i64,
                _ => return None,
            };
            // `x & (n - 1)` equals `x % n` only for non-negative `x`.
            let (vmin, _) = VminVmaxProperty::get(x);
            let is_non_negative = match vmin {
                ConstValue::Int(v) => *v >= 0,
                ConstValue::UInt(_) => true,
                _ => false,
            };
            if !is_non_negative && !x.dtype().is_unsigned() { return None; }
            let mask = UOp::const_(x.dtype(), ConstValue::Int(n - 1));
            x.try_and_op(&mask).ok()
        },
    }
}
/// Rewrites integer `x * 2^k` into `x << k`.
///
/// The `Mul[..]` form presumably matches either operand order. Multiplication by
/// `1` is simplified to `x` itself. Non-positive and non-power-of-two constants
/// are left alone.
pub fn pm_mul_to_shl() -> &'static TypedPatternMatcher<()> {
use morok_ir::types::ConstValue;
crate::cached_patterns! {
Mul[x, _c @const(c_val)] => |x, c_val| {
if !x.dtype().is_int() { return None; }
// (c, log2(c)) for positive powers of two only.
let (n, shift) = match c_val {
ConstValue::Int(v) if v > 0 && (v as u64).is_power_of_two() => (v as u64, (v as u64).trailing_zeros()),
ConstValue::UInt(v) if v > 0 && v.is_power_of_two() => (v, v.trailing_zeros()),
_ => return None,
};
// x * 1 -> x; otherwise shift by log2(c).
if n == 1 { return Some(x.clone()); } let shift_amount = UOp::const_(x.dtype(), ConstValue::Int(shift as i64));
x.try_shl_op(&shift_amount).ok()
},
}
}
/// Canonicalizes negation.
///
/// - `x * c` with `c == -1` becomes `Neg(x)`, keeping `x`'s dtype.
/// - `x + Neg(y)` becomes `x - y`.
pub fn pm_neg_from_mul() -> &'static TypedPatternMatcher<()> {
crate::cached_patterns! {
Mul[x, _c @const(c_val)] if c_val.is_neg_one() => |x| {
let dtype = x.dtype();
Some(UOp::new(Op::Unary(UnaryOp::Neg, x.clone()), dtype))
},
Add[x, Neg(y)] ~> x.sub(y),
}
}
/// Expands `Threefry(x, key)` with a `UInt64` `x` into plain integer arithmetic via `threefry2x32`.
pub fn pm_threefry_decomp() -> &'static TypedPatternMatcher<()> {
crate::cached_patterns! {
Threefry(x, key) if x.dtype() == DType::UInt64 => |x, key| {
Some(threefry2x32(x, key))
},
}
}
/// Threefry-2x32 block function (5 groups x 4 = 20 rounds) built from UOp arithmetic.
///
/// `x` and `key` are each split into a low and a high 32-bit word. The two
/// output words are packed back into a `u64` as `(xr1 << 32) | xr0`. Rotations
/// are emulated with `mul`/`idiv` by powers of two.
/// NOTE(review): the callers presumably pass `u64` `x` and `key`
/// (`pm_threefry_decomp` only checks `x`). The rotation emulation assumes `u32`
/// `mul`/`add` wrap modulo 2^32 — confirm.
fn threefry2x32(x: &Arc<UOp>, key: &Arc<UOp>) -> Arc<UOp> {
let u32_dt = DType::Scalar(morok_dtype::ScalarDType::UInt32);
let u64_dt = DType::Scalar(morok_dtype::ScalarDType::UInt64);
let mask32 = UOp::const_(u64_dt.clone(), ConstValue::UInt(0xFFFFFFFF));
let pow32 = UOp::const_(u64_dt.clone(), ConstValue::UInt(1u64 << 32));
// Low word = v & 0xFFFFFFFF, high word = (v / 2^32) & 0xFFFFFFFF, both cast to u32.
let x0 = x.and_(&mask32).cast(u32_dt.clone());
let x1 = x.idiv(&pow32).and_(&mask32).cast(u32_dt.clone());
let key0 = key.and_(&mask32).cast(u32_dt.clone());
let key1 = key.idiv(&pow32).and_(&mask32).cast(u32_dt.clone());
// Key schedule: the extra key word is key0 ^ key1 ^ 0x1BD11BDA (Threefry's parity constant).
let skein_const = UOp::const_(u32_dt.clone(), ConstValue::UInt(0x1BD11BDA));
let ks = [key1.clone(), key0.xor(&key1).xor(&skein_const), key0.clone()];
// Rotation amounts for even / odd groups of four rounds.
let rotations: [[u32; 4]; 2] = [[13, 15, 26, 6], [17, 29, 16, 24]];
// Initial key injection.
let mut xr0 = x0.add(&ks[2]);
let mut xr1 = x1.add(&ks[0]);
for i in 0..5u32 {
for &r in &rotations[i as usize % 2] {
// Mix round: x0 += x1; x1 = rotl(x1, r) ^ x0, using the updated x0.
let new_x0 = xr0.add(&xr1);
// rotl(v, r) = v * 2^r + v / 2^(32 - r); the two parts occupy disjoint bits.
let rot_left = xr1.mul(&UOp::const_(u32_dt.clone(), ConstValue::UInt(1u64 << r)));
let rot_right = xr1.idiv(&UOp::const_(u32_dt.clone(), ConstValue::UInt(1u64 << (32 - r))));
let rotated = rot_left.add(&rot_right);
xr1 = new_x0.xor(&rotated);
xr0 = new_x0;
}
// Key injection after every four rounds; the second word also adds the counter i + 1.
xr0 = xr0.add(&ks[i as usize % 3]);
let round_const = UOp::const_(u32_dt.clone(), ConstValue::UInt((i + 1) as u64));
xr1 = xr1.add(&ks[(i as usize + 1) % 3]).add(&round_const);
}
// Pack as (xr1 * 2^32) | xr0 in u64.
xr1.cast(u64_dt.clone()).mul(&pow32).or_(&xr0.cast(u64_dt))
}
/// De Morgan rewrite on bools: `!x & !y` becomes `!(x | y)`.
pub fn pm_demorgan() -> &'static TypedPatternMatcher<()> {
crate::cached_patterns! {
And[Not(x), Not(y)] if x.dtype().is_bool() ~> x.or_(y).not(),
}
}
/// Fuses `(x << n) + c` into `MulAcc(x, 1 << n, c)` for a constant `Int` shift `0 <= n < 64`.
///
/// Shift amounts given as `UInt` constants, or outside that range, are left alone.
pub fn pm_shl_add_to_mulacc() -> &'static TypedPatternMatcher<()> {
crate::cached_patterns! {
Add[Shl(x, _n @const(nv)), c] => |x, nv, c| {
let ConstValue::Int(v) = nv else { return None };
// Keep `1i64 << v` within the i64 shift width.
if !(0..64).contains(&v) { return None; }
let multiplier = UOp::const_(x.dtype(), ConstValue::Int(1i64 << v));
UOp::try_mulacc(x.clone(), multiplier, c.clone()).ok()
},
}
}
/// Rewrites integer `x / 2^k` into a right shift by `k`.
///
/// - Unsigned `x`, or `x` whose `vmin` is `>= 0`, becomes a plain `x >> k`.
/// - Otherwise `2^k - 1` is added to `x` when `x < 0` before shifting. This makes
///   the shift round toward zero (assuming `Shr` is arithmetic on signed types —
///   confirm).
///
/// Division by `1` and non-power-of-two divisors are not rewritten.
pub fn pm_div_to_shr() -> &'static TypedPatternMatcher<()> {
use morok_ir::types::ConstValue;
use morok_ir::uop::cached_property::CachedProperty;
use morok_ir::uop::properties::VminVmaxProperty;
crate::cached_patterns! {
Idiv(x, _c @const(c_val)) => |x, c_val| {
if !x.dtype().is_int() { return None; }
let n = match c_val {
ConstValue::Int(v) if v > 0 && (v as u64).is_power_of_two() => v,
ConstValue::UInt(v) if v > 0 && v.is_power_of_two() => v as i64,
_ => return None,
};
if n == 1 { return None; }
let shift = (n as u64).trailing_zeros() as i64;
let shift_const = UOp::const_(x.dtype(), ConstValue::Int(shift));
// A provably non-negative dividend needs no rounding correction.
let (vmin, _) = VminVmaxProperty::get(x);
let is_non_negative = match vmin {
ConstValue::Int(v) => *v >= 0,
ConstValue::UInt(_) => true, _ => false,
};
if is_non_negative || x.dtype().is_unsigned() {
x.try_shr_op(&shift_const).ok()
} else {
// (x + (x < 0 ? n - 1 : 0)) >> k
let zero = UOp::const_(x.dtype(), ConstValue::Int(0));
let bias = UOp::const_(x.dtype(), ConstValue::Int(n - 1));
let x_neg = x.try_cmplt(&zero).ok()?;
let adjustment = UOp::try_where(x_neg, bias, zero).ok()?;
let adjusted = x.try_add(&adjustment).ok()?;
adjusted.try_shr_op(&shift_const).ok()
}
},
}
}
/// Rewrites `max(a, b)` as `where(a < b, b, a)`.
pub fn pm_max_decomposition() -> &'static TypedPatternMatcher<()> {
crate::cached_patterns! {
Max(a, b) => |a, b| {
let cond = a.try_cmplt(b).ok()?;
UOp::try_where(cond, b.clone(), a.clone()).ok()
},
}
}
/// Rewrites float `sqrt(x)` as `pow(x, 0.5)`.
pub fn pm_sqrt_decomposition() -> &'static TypedPatternMatcher<()> {
crate::cached_patterns! {
Sqrt(x) if x.dtype().is_float() => |x| {
let half = UOp::const_(x.dtype(), morok_ir::types::ConstValue::Float(0.5));
x.try_pow(&half).ok()
},
}
}
/// Rewrites float `erf(x)` with the Abramowitz & Stegun 7.1.26 approximation.
///
/// The formula is:
/// - `t = 1 / (1 + p|x|)`;
/// - `erf(x) ≈ sign(x) * (1 - (a1 t + a2 t^2 + a3 t^3 + a4 t^4 + a5 t^5) * exp(-x^2))`.
///
/// The polynomial is evaluated in Horner form. A&S quote `|error| <= 1.5e-7`.
pub fn pm_erf_decomposition() -> &'static TypedPatternMatcher<()> {
crate::cached_patterns! {
Erf(x) if x.dtype().is_float() => |x| {
let dt = x.dtype();
let f = |v: f64| UOp::const_(dt.clone(), ConstValue::Float(v));
let abs_x = x.abs();
// t = 1 / (1 + p * |x|), p = 0.3275911
let t = f(1.0).try_div(&f(1.0).try_add(&f(0.3275911).try_mul(&abs_x).ok()?).ok()?).ok()?;
// Horner: ((((a5 t + a4) t + a3) t + a2) t + a1)
let poly = f(1.061405429);
let poly = poly.try_mul(&t).ok()?.try_add(&f(-1.453152027)).ok()?;
let poly = poly.try_mul(&t).ok()?.try_add(&f(1.421413741)).ok()?;
let poly = poly.try_mul(&t).ok()?.try_add(&f(-0.284496736)).ok()?;
let poly = poly.try_mul(&t).ok()?.try_add(&f(0.254829592)).ok()?;
let exp_val = x.square().neg().try_exp().ok()?;
// 1 - t * poly * exp(-x^2), then restore the sign of x.
let inner = f(1.0).try_sub(&t.try_mul(&poly).ok()?.try_mul(&exp_val).ok()?).ok()?;
x.sign().try_mul(&inner).ok()
},
}
}
/// Rewrites float `x / c` as `x * (1 / c)` for a non-zero constant with a finite reciprocal.
///
/// This is not bit-exact in general: the result can differ in the last ulp
/// unless `1 / c` is exactly representable (e.g. `c` a power of two).
pub fn pm_fdiv_to_mul() -> &'static TypedPatternMatcher<()> {
use morok_ir::types::ConstValue;
crate::cached_patterns! {
Fdiv(x, _c @const(c_val)) => |x, c_val| {
if !x.dtype().is_float() { return None; }
let f = match c_val {
ConstValue::Float(v) => v,
_ => return None,
};
if f == 0.0 { return None; }
let recip = 1.0 / f;
// Skip divisors whose reciprocal overflows (e.g. subnormals).
if !recip.is_finite() { return None; }
let recip_const = UOp::const_(x.dtype(), ConstValue::Float(recip));
x.try_mul(&recip_const).ok()
},
}
}
/// Integer comparison canonicalizations that remove `Not` and `* -1`.
///
/// Rewrites:
/// - `!(x < c)` becomes `(c - 1) < x`.
/// - `!(c < x)` becomes `x < (c + 1)`.
/// - `(c1 < x) & (x < c2)` with `c2 == c1 + 2` becomes `x == c1 + 1`.
/// - `(x * -1) < c` becomes `-c < x`.
/// - `(x * -1) < (y * c)` becomes `(y * -c) < x`.
///
/// Constraints:
/// - Constants are read as `i64`; a `UInt` that does not fit is skipped.
/// - `checked_*` arithmetic skips rules whose adjusted constant overflows `i64`.
/// - For an unsigned `x`, `!(x < c)` is not rewritten when `c - 1` would be
///   negative: that constant wraps in an unsigned dtype, and `!(x < 0)`
///   (always true) would become an always-false comparison.
///
/// NOTE(review): adjusted constants are only range-checked against `i64`, not
/// against a narrower dtype of `x` (e.g. `!(x < i32::MIN)` on `i32`, or
/// `!(u32::MAX < x)` on `u32`). Confirm such constants cannot reach these rules.
pub fn pm_comparison_negations() -> &'static TypedPatternMatcher<()> {
    use morok_ir::types::ConstValue;
    crate::cached_patterns! {
        Not(Lt(x, _c @const(c_val))) if x.dtype().is_int() => |x, c_val| {
            let v = match c_val {
                ConstValue::Int(v) => v,
                ConstValue::UInt(v) => i64::try_from(v).ok()?,
                _ => return None,
            };
            let c_minus_1 = v.checked_sub(1)?;
            // A negative constant would wrap in an unsigned dtype.
            if c_minus_1 < 0 && x.dtype().is_unsigned() { return None; }
            let c_minus_1_const = UOp::const_(x.dtype(), ConstValue::Int(c_minus_1));
            c_minus_1_const.try_cmplt(x).ok()
        },
        Not(Lt(_c @const(c_val), x)) if x.dtype().is_int() => |x, c_val| {
            let v = match c_val {
                ConstValue::Int(v) => v,
                ConstValue::UInt(v) => i64::try_from(v).ok()?,
                _ => return None,
            };
            let c_plus_1 = v.checked_add(1)?;
            let c_plus_1_const = UOp::const_(x.dtype(), ConstValue::Int(c_plus_1));
            x.try_cmplt(&c_plus_1_const).ok()
        },
        And[Lt(_c1 @const(c1_val), x), Lt(x2, _c2 @const(c2_val))]
            if x.dtype().is_int() && Arc::ptr_eq(x, x2)
            => |x, c1_val, c2_val| {
            let v1 = match c1_val {
                ConstValue::Int(v) => v,
                ConstValue::UInt(v) => i64::try_from(v).ok()?,
                _ => return None,
            };
            let v2 = match c2_val {
                ConstValue::Int(v) => v,
                ConstValue::UInt(v) => i64::try_from(v).ok()?,
                _ => return None,
            };
            // Exactly one integer lies strictly between c1 and c1 + 2.
            if v2 != v1.checked_add(2)? { return None; }
            let target = UOp::const_(x.dtype(), ConstValue::Int(v1 + 1));
            x.try_cmpeq(&target).ok()
        },
        Lt(Mul(x, _neg1 @const(neg_val)), _c @const(c_val)) if x.dtype().is_int() => |x, neg_val, c_val| {
            if !matches!(neg_val, ConstValue::Int(-1)) { return None; }
            let c = match c_val {
                ConstValue::Int(v) => v,
                ConstValue::UInt(v) => i64::try_from(v).ok()?,
                _ => return None,
            };
            let neg_c = c.checked_neg()?;
            let neg_c_const = UOp::const_(x.dtype(), ConstValue::Int(neg_c));
            neg_c_const.try_cmplt(x).ok()
        },
        Lt(Mul(x, _neg1 @const(neg_val)), Mul(y, _c @const(c_val))) if x.dtype().is_int() => |x, neg_val, y, c_val| {
            if !matches!(neg_val, ConstValue::Int(-1)) { return None; }
            let c = match c_val {
                ConstValue::Int(v) => v,
                ConstValue::UInt(v) => i64::try_from(v).ok()?,
                _ => return None,
            };
            let neg_c = c.checked_neg()?;
            let neg_c_const = UOp::const_(y.dtype(), ConstValue::Int(neg_c));
            let y_neg_c = y.try_mul(&neg_c_const).ok()?;
            y_neg_c.try_cmplt(x).ok()
        },
    }
}