use morok_dtype::{DType, ScalarDType};
use morok_ir::types::{BinaryOp, ConstValue, ConstValueHash};
use morok_ir::uop::cached_property::CachedProperty;
use morok_ir::uop::comparison_analysis::ComparisonAnalyzer;
use morok_ir::uop::eval::{
eval_add_typed, eval_binary_op, eval_binary_op_broadcast, eval_binary_op_broadcast_typed, eval_mul_typed,
eval_sub_typed, eval_unary_op_vec_typed,
};
use morok_ir::uop::properties::VminVmaxProperty;
use morok_ir::{IntoUOp, Op, UOp};
use crate::TypedPatternMatcher;
use crate::rangeify::indexing::get_const_value;
use crate::symbolic::dce::is_empty_range;
use smallvec::SmallVec;
use std::sync::Arc;
use tracing::trace;
/// Returns the matcher holding the scalar constant-folding rules.
///
/// Every arm matches an operation whose operands are all `@const` and
/// evaluates it with the corresponding `eval_*_typed` helper. When the helper
/// yields `None` the arm produces no rewrite.
pub fn constant_folding_dsl_patterns() -> &'static TypedPatternMatcher {
use morok_ir::uop::eval::{eval_binary_op_typed, eval_ternary_op_typed, eval_unary_op_typed};
crate::cached_patterns! {
// Unary op on a constant: evaluated in the operand's base dtype, result
// rebuilt as a constant of the operand's dtype.
for op in unary [Sqrt, Exp2, Log2, Sin, Reciprocal, Trunc] {
op(c @const(c_val))
=> eval_unary_op_typed(op, c_val, c.dtype().base()).map(|r| UOp::const_(c.dtype(), r)),
},
// Binary op on two constants: the left operand `a` supplies both the
// evaluation dtype and the result dtype.
for op in binary [Add, Mul, Sub, Mod, Max, Pow, Idiv, Fdiv, And, Or, Xor, Shl, Shr] {
op(a @const(a_val), _b @const(b_val))
=> eval_binary_op_typed(op, a_val, b_val, a.dtype().base()).map(|r| UOp::const_(a.dtype(), r)),
},
// Ternary op on three constants: the second operand `b` supplies the
// evaluation dtype and the result dtype.
for op in ternary [Where, MulAcc] {
op(_a @const(a_val), b @const(b_val), _c @const(c_val))
=> eval_ternary_op_typed(op, a_val, b_val, c_val, b.dtype().base()).map(|r| UOp::const_(b.dtype(), r)),
},
}
}
/// Returns the matcher holding the constant-folding rules for vector
/// constants (`@vconst`) and for mixed operands matched by `@anyconst`.
///
/// Arithmetic results are rebuilt with `UOp::vconst` using the scalar dtype of
/// the left operand; comparison results are rebuilt with `DType::Bool`.
pub fn vconst_folding_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
// vconst (op) vconst, arithmetic/bitwise: evaluated in the base dtype of `a`.
for op in binary [Add, Mul, Sub, Mod, Max, Idiv, And, Or, Xor, Shl, Shr] {
op(a @vconst(vals_a), _b @vconst(vals_b))
=> {
let dt = a.dtype().scalar_dtype();
eval_binary_op_broadcast_typed(op, &vals_a, &vals_b, a.dtype().base())
.map(|v| UOp::vconst(v, dt))
},
},
// vconst (cmp) vconst: untyped evaluation, boolean result.
for op in binary [Lt, Le, Eq, Ne, Gt, Ge] {
op(_a @vconst(vals_a), _b @vconst(vals_b))
=> {
eval_binary_op_broadcast(op, &vals_a, &vals_b)
.map(|v| UOp::vconst(v, DType::Bool))
},
},
// Same arithmetic fold for `@anyconst` operands, but only when the two value
// lists have different lengths (the guard excludes equal-length operands).
for op in binary [Add, Mul, Sub, Mod, Max, Idiv, And, Or, Xor, Shl, Shr] {
op(a @anyconst(vals_a), _b @anyconst(vals_b))
if vals_a.len() != vals_b.len()
=> {
let dt = a.dtype().scalar_dtype();
eval_binary_op_broadcast_typed(op, &vals_a, &vals_b, a.dtype().base()).map(|v| UOp::vconst(v, dt))
},
},
// Same comparison fold for `@anyconst` operands of different lengths.
for op in binary [Lt, Le, Eq, Ne, Gt, Ge] {
op(_a @anyconst(vals_a), _b @anyconst(vals_b))
if vals_a.len() != vals_b.len()
=> eval_binary_op_broadcast(op, &vals_a, &vals_b).map(|v| UOp::vconst(v, DType::Bool)),
},
// Unary op applied to a vector constant, evaluated element by element.
for op in unary [Sqrt, Exp2, Log2, Sin, Reciprocal, Trunc] {
op(a @vconst(vals))
=> {
let dt = a.dtype().scalar_dtype();
eval_unary_op_vec_typed(op, &vals, a.dtype().base()).map(|v| UOp::vconst(v, dt))
},
},
}
}
/// Returns the matcher that turns a `Vectorize` whose lanes are all scalar
/// constants into a single vector constant.
///
/// The dtype passed to `UOp::vconst` is the dtype of the first lane.
pub fn vectorize_to_vconst_patterns() -> &'static TypedPatternMatcher {
    crate::cached_patterns! {
        Vectorize { elements }
            if !elements.is_empty() && elements.iter().all(|lane| matches!(lane.op(), Op::Const(_)))
            => {
                let lane_dtype = elements[0].dtype();
                // Collecting into `Option` yields `None` as soon as a lane is not a
                // constant, so no rewrite happens unless every lane contributes a value.
                elements
                    .iter()
                    .map(|lane| match lane.op() {
                        Op::Const(cv) => Some(cv.0),
                        _ => None,
                    })
                    .collect::<Option<Vec<ConstValue>>>()
                    .map(|lanes| UOp::vconst(lanes, lane_dtype))
            },
    }
}
/// Returns the matcher that rewrites arithmetic on two `Bool` operands into
/// the equivalent logical operation.
pub fn bool_arithmetic_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
// bool * bool -> bool AND bool
Mul[x, y] if x.dtype() == DType::Bool && y.dtype() == DType::Bool ~> x.and_(y),
// bool + bool -> bool OR bool
Add[x, y] if x.dtype() == DType::Bool && y.dtype() == DType::Bool ~> x.or_(y),
// max(bool, bool) -> bool OR bool
Max(x, y) if x.dtype() == DType::Bool && y.dtype() == DType::Bool ~> x.or_(y),
}
}
/// Returns the matcher holding identity-element and absorbing-element rules
/// (`x + 0`, `x * 1`, `x * 0`, `x & 0`, ...).
///
/// NOTE(review): the bracket form (`Add[x, @zero]`) is presumably the
/// DSL's order-insensitive match while the parenthesised form is
/// order-sensitive — confirm against the `cached_patterns!` definition.
pub fn identity_and_zero_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
// Identity elements: the operation is replaced by `x` itself.
Add[x, @zero] ~> x.clone(),
Mul[x, @one] ~> x.clone(),
Or[x, @zero] ~> x.clone(),
Xor[x, @zero] ~> x.clone(),
Sub(x, @zero) ~> x.clone(),
Idiv(x, @one) ~> x.clone(),
Fdiv(x, @one) ~> x.clone(),
// x % 1 -> zero constant of x's dtype (no rewrite if the dtype has no scalar).
Mod(x, @one) => x.dtype().scalar().map(|dt| UOp::const_(x.dtype(), ConstValue::zero(dt))),
// Rounding a non-float value is a no-op; `op` is only bound by the loop.
for op in unary [Floor, Ceil, Trunc, Round] {
op(x) if !x.dtype().is_float() ~> { let _ = op; x.clone() }
},
// x * 0: a float constant `x` that is NaN or infinite folds to NaN;
// every other `x` folds to a zero constant of x's dtype.
Mul[x, _zero @ @zero] => {
if let Op::Const(ConstValueHash(ConstValue::Float(f))) = x.op()
&& (f.is_nan() || f.is_infinite()) {
return Some(UOp::const_(x.dtype(), ConstValue::Float(f64::NAN)));
}
x.dtype().scalar().map(|dt| UOp::const_(x.dtype(), ConstValue::zero(dt)))
},
// x & 0 -> the matched zero operand.
And[_, zero @ @zero] ~> zero.clone(),
}
}
/// Returns the matcher that moves `Op::Invalid` markers outward through
/// `Where`, `Cast`, binary arithmetic and comparisons.
pub fn propagate_invalid() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
// Where(cond, Invalid, x): flip the condition so that Invalid ends up in
// the false branch. An existing `Not` is unwrapped instead of negated again,
// and the Invalid node is rebuilt with x's dtype when the dtypes differ.
Where(cond, inv, x) if matches!(inv.op(), Op::Invalid) => {
let invalid = if inv.dtype() == x.dtype() { inv.clone() } else { UOp::new(Op::Invalid, x.dtype()) };
let flipped = match cond.op() {
Op::Unary(morok_ir::UnaryOp::Not, inner) => Arc::clone(inner),
_ => cond.not(),
};
UOp::try_where(flipped, x.clone(), invalid).ok()
},
// Nested Where sharing the same false branch `d`: merge the two conditions
// with AND.
Where(c1, Where(c2, x, d), d) ~> {
let combined = c1.and_(c2);
UOp::try_where(combined, x.clone(), d.clone()).expect("failed to create WHERE")
},
// Cast of Where(_, x, Invalid): the Where is dropped and `x` is cast directly.
Cast { src: Where(_cond, x, invalid), dtype } if matches!(invalid.op(), Op::Invalid) ~> x.cast(dtype.clone()),
// Binary op with a gated left operand: push the op into the true branch and
// keep the Where (with its Invalid false branch) on the outside.
for op in binary [Add, Mul, Sub, Mod, Max, Idiv, Fdiv, Pow, And, Or, Xor, Shl, Shr] {
r @ op(Where(cond, x, invalid), y)
if matches!(invalid.op(), Op::Invalid)
~> {
let inner = UOp::new(Op::Binary(op, x.clone(), y.clone()), r.dtype());
UOp::try_where(cond.clone(), inner, invalid.clone()).expect("failed to create WHERE")
},
},
// Same as above with the gated value as the right operand.
for op in binary [Add, Mul, Sub, Mod, Max, Idiv, Fdiv, Pow, And, Or, Xor, Shl, Shr] {
r @ op(y, Where(cond, x, invalid))
if matches!(invalid.op(), Op::Invalid)
~> {
let inner = UOp::new(Op::Binary(op, y.clone(), x.clone()), r.dtype());
UOp::try_where(cond.clone(), inner, invalid.clone()).expect("failed to create WHERE")
},
},
// Comparisons: the gate is discarded entirely and `x` is compared directly.
for op in binary [Lt, Le, Eq, Ne, Gt, Ge] {
r @ op(Where(_cond, x, invalid), y)
if matches!(invalid.op(), Op::Invalid)
~> UOp::new(Op::Binary(op, x.clone(), y.clone()), r.dtype()),
},
for op in binary [Lt, Le, Eq, Ne, Gt, Ge] {
r @ op(y, Where(_cond, x, invalid))
if matches!(invalid.op(), Op::Invalid)
~> UOp::new(Op::Binary(op, y.clone(), x.clone()), r.dtype()),
},
// Bare Invalid as the left operand with an Index-typed right operand: the
// whole expression becomes Invalid.
for op in binary [Add, Mul, Sub, Mod, Max, Idiv, Fdiv, Pow, And, Or, Xor, Shl, Shr] {
op(invalid, y) if matches!(invalid.op(), Op::Invalid) && y.dtype() == DType::Index
~> { let _ = op; invalid.clone() },
},
// Bare Invalid as the right operand; note the shorter operator list here.
for op in binary [Add, Mul, Max, And, Or, Xor] {
op(y, invalid) if matches!(invalid.op(), Op::Invalid) && y.dtype() == DType::Index
~> { let _ = op; invalid.clone() },
},
}
}
/// Returns the matcher that removes loads and stores whose index contains an
/// `Op::Invalid` entry.
///
/// Such a load becomes a zero constant shaped like the load; such a store
/// becomes a `Noop`. Both the direct `Index` form and the `Cast(Index)` form
/// are handled.
pub fn fold_invalid_load_store() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
// Load through an index with an Invalid entry -> zero constant.
load @ Load { index: Index { indices, .. }, .. }
if indices.iter().any(|idx| matches!(idx.op(), Op::Invalid))
=> {
// `?` skips the rewrite when the load dtype has no scalar type.
let zero = ConstValue::zero(load.dtype().scalar()?);
Some(load.const_like(zero))
},
// Same, with the index wrapped in a Cast.
load @ Load { index: Cast { src: Index { indices, .. }, .. }, .. }
if indices.iter().any(|idx| matches!(idx.op(), Op::Invalid))
=> {
let zero = ConstValue::zero(load.dtype().scalar()?);
Some(load.const_like(zero))
},
// Store through an index with an Invalid entry -> Noop.
Store { index: Index { indices, .. }, .. }
if indices.iter().any(|idx| matches!(idx.op(), Op::Invalid))
~> UOp::new(Op::Noop, DType::Void),
// Same, with the index wrapped in a Cast.
Store { index: Cast { src: Index { indices, .. }, .. }, .. }
if indices.iter().any(|idx| matches!(idx.op(), Op::Invalid))
~> UOp::new(Op::Noop, DType::Void),
}
}
/// Returns the basic symbolic rule set: the listed matchers combined with `+`
/// in the order written.
///
/// The combined matcher is built on first access and stored in a `LazyLock`,
/// so later calls return the same `'static` reference.
/// NOTE(review): the order of the `+` chain presumably determines rule
/// priority — confirm before reordering.
pub fn symbolic_simple() -> &'static TypedPatternMatcher {
static CACHED: std::sync::LazyLock<TypedPatternMatcher> = std::sync::LazyLock::new(|| {
propagate_invalid()
+ constant_folding_dsl_patterns()
+ vconst_folding_patterns() + bool_arithmetic_patterns()
+ identity_and_zero_patterns()
+ self_folding_dsl_patterns()
+ zero_folding_dsl_patterns()
+ division_dsl_patterns()
+ cast_dsl_patterns()
+ div_mod_recombine_dsl_patterns()
+ power_dsl_patterns()
+ boolean_dsl_simple_patterns()
+ dce_dsl_simple_patterns()
+ dead_loop_patterns()
});
&CACHED
}
/// Returns the extended symbolic rule set: `symbolic_simple()` followed by
/// commutative canonicalization and the further rule groups listed below,
/// combined with `+` in the order written.
///
/// The combined matcher is built on first access and stored in a `LazyLock`.
pub fn symbolic() -> &'static TypedPatternMatcher {
static CACHED: std::sync::LazyLock<TypedPatternMatcher> = std::sync::LazyLock::new(|| {
symbolic_simple()
+ commutative_canonicalization()
+ boolean_dsl_patterns() + term_combining_dsl_patterns() + dce_dsl_patterns() + where_alu_combining_patterns() + vmin_vmax_collapse_patterns() + minmax_dsl_patterns() + alu_folding_dsl_patterns() + comparison_dsl_patterns() + range_based_mod_div_patterns() + advanced_division_dsl_patterns() + range_based_cast_patterns() + long_to_int_narrowing_patterns() + vectorize_to_vconst_patterns() + after_simplification_patterns() + where_bound_patterns() + gep_pushing_patterns() });
&CACHED
}
/// Returns the largest rule set defined here: `symbolic()` followed by
/// validity simplification, load/store folding, reduce rules and the other
/// groups listed below, combined with `+` in the order written.
///
/// The combined matcher is built on first access and stored in a `LazyLock`.
pub fn sym() -> &'static TypedPatternMatcher {
static CACHED: std::sync::LazyLock<TypedPatternMatcher> = std::sync::LazyLock::new(|| {
symbolic()
+ super::valid_simplification::pm_simplify_valid()
+ alu_vectorize_reorder_patterns()
+ ne_zero_fold_patterns()
+ cast_where_dsl_patterns()
+ fold_invalid_load_store()
+ store_load_folding_patterns()
+ reciprocal_patterns()
+ reduce_sym_patterns()
+ sym_phase3_patterns()
});
&CACHED
}
/// Returns the matcher that puts the operands of the listed binary ops into a
/// fixed order for `Index`-typed results.
///
/// When the right operand's `id` is smaller than the left operand's, the two
/// operands are swapped so the smaller `id` comes first.
fn commutative_canonicalization() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
for op in binary [Add, Mul, Max, Eq, Ne, And, Or, Xor] {
r @ op(a, b)
// Only Index-typed results are reordered; the `id` comparison makes the
// rewrite fire at most once per operand pair.
if r.dtype() == DType::Index && b.id < a.id
~> UOp::new(Op::Binary(op, b.clone(), a.clone()), r.dtype()),
},
}
}
/// Returns the matcher holding rules for operations whose operands repeat
/// (`x / x`, `x & x`, ...) plus division by minus one.
pub fn self_folding_dsl_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
// x / x -> 1 (in x's dtype)
Idiv(x, x) ~> 1.into_uop(x.dtype()),
// x / -1 -> -x
Idiv(x, _c @const(c_val)) if c_val.is_neg_one() ~> x.neg(),
// (x % y) % y -> x % y
Mod(Mod(x, y), y) ~> x.mod_(y),
// Idempotent operations on identical operands collapse to the operand.
And(x, x) ~> x.clone(),
Max(x, x) ~> x.clone(),
Or(x, x) ~> x.clone(),
}
}
/// Returns the matcher holding rules where identical operands make the
/// result a zero or `false` constant.
pub fn zero_folding_dsl_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
// x % x -> zero constant of x's dtype (no rewrite if the dtype has no scalar).
Mod(x, x) => x.dtype().scalar().map(|dt| UOp::const_(x.dtype(), ConstValue::zero(dt))),
// x < x -> false, as a Bool with the same vector width as x.
Lt(x, x) => Some(UOp::const_(DType::Bool.vec(x.dtype().vcount()), ConstValue::Bool(false))),
// x != x -> false, restricted to integer and boolean dtypes.
Ne(x, x) if x.dtype().is_int() || x.dtype().is_bool() =>
Some(UOp::const_(DType::Bool.vec(x.dtype().vcount()), ConstValue::Bool(false))),
}
}
pub fn range_based_mod_div_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
Mod(range @ Range { end, .. }, end) ~> range,
Idiv(Range { end, .. }, end) ~> UOp::index_const(0),
Idiv(x, d) if x.dtype() == DType::Index && matches!(d.vmax(), ConstValue::Int(v) if *v < 0)
=> Some(x.idiv(&d.neg()).neg()),
Idiv(x, d) if x.dtype() == DType::Index && matches!(x.vmax(), ConstValue::Int(v) if *v <= 0)
=> Some(x.neg().idiv(d).neg()),
Mod(x, d) if x.dtype() == DType::Index && matches!(x.vmax(), ConstValue::Int(v) if *v <= 0)
=> Some(x.neg().mod_(d).neg()),
Mod(x, d) if x.dtype() == DType::Index && matches!(d.vmax(), ConstValue::Int(v) if *v < 0)
=> Some(x.mod_(&d.neg())),
Mod(x, _n @const(n_val)) => {
let (vmin, vmax) = VminVmaxProperty::get(x);
trace!(
x.id = x.id,
vmin = ?vmin,
vmax = ?vmax,
n_val = ?n_val,
"Mod simplification check"
);
if let (ConstValue::Int(min), ConstValue::Int(max), ConstValue::Int(n_int)) = (vmin, vmax, n_val)
&& *min >= 0 && *max < n_int {
trace!(
n_int,
min = *min,
max = *max,
"Simplifying x % n → x"
);
return Some(Arc::clone(x));
}
None
},
Mod(Add[Mul[_a, n @const(_m_val)], b], n @const(_n_val)) => {
if !matches!(b.vmin(), ConstValue::Int(v) if *v >= 0) { return None; }
Some(b.mod_(n))
},
Mod(Add[Add[Mul[_a, n @const(_m_val)], b], c], n @const(_n_val)) => {
let bc = b.add(c);
if !matches!(bc.vmin(), ConstValue::Int(v) if *v >= 0) { return None; }
Some(bc.mod_(n))
},
Idiv(Add[Mul[a, n @const(_m_val)], b], n @const(n_val)) => {
let n_int = n_val.try_int()?;
if n_int <= 0 { return None; }
let (vmin, vmax) = VminVmaxProperty::get(b);
if let (ConstValue::Int(min), ConstValue::Int(max)) = (vmin, vmax)
&& *min >= 0 && *max < n_int {
trace!(
?n_val,
a.id = a.id,
min = *min,
max = *max,
"Idiv factor-out: (a * n + b) / n → a (when 0 <= b < n)"
);
return Some(Arc::clone(a));
}
if !matches!(b.vmin(), ConstValue::Int(v) if *v >= 0) { return None; }
let b_div_n = b.idiv(n);
Some(a.add(&b_div_n))
},
Idiv(x, _n @const(n_val)) => {
let (vmin, vmax) = VminVmaxProperty::get(x);
if let (ConstValue::Int(min), ConstValue::Int(max), ConstValue::Int(n_int)) = (vmin, vmax, n_val)
&& n_int > 0 {
let min_div = *min / n_int;
let max_div = *max / n_int;
if min_div == max_div {
trace!(
min = *min,
max = *max,
n_int,
result = min_div,
"Idiv cancel: x / n → k (all values in same bucket)"
);
return Some(UOp::const_(x.dtype(), ConstValue::Int(min_div)));
}
}
None
},
Idiv(x @ Add(_, _), n @const(n_val)) => {
fn extract_const_sum(uop: &Arc<UOp>) -> (Arc<UOp>, i64) {
match uop.op() {
Op::Binary(BinaryOp::Add, left, right) => {
if let Op::Const(cv) = right.op()
&& let ConstValue::Int(v) = cv.0 {
let (inner, inner_sum) = extract_const_sum(left);
return (inner, inner_sum + v);
}
if let Op::Const(cv) = left.op()
&& let ConstValue::Int(v) = cv.0 {
let (inner, inner_sum) = extract_const_sum(right);
return (inner, inner_sum + v);
}
let (left_inner, left_sum) = extract_const_sum(left);
let (right_inner, right_sum) = extract_const_sum(right);
if left_sum != 0 || right_sum != 0 {
let new_add = left_inner.try_add(&right_inner).ok();
if let Some(rebuilt) = new_add {
return (rebuilt, left_sum + right_sum);
}
}
(Arc::clone(uop), 0)
}
_ => (Arc::clone(uop), 0),
}
}
let (x_without_const, const_sum) = extract_const_sum(x);
if const_sum == 0 {
return None; }
let (vmin, vmax) = VminVmaxProperty::get(&x_without_const);
if let (ConstValue::Int(min), ConstValue::Int(max), ConstValue::Int(n_int)) = (vmin, vmax, n_val)
&& n_int > 0 {
let min_div = *min / n_int;
let max_div = *max / n_int;
if min_div != max_div {
return None;
}
let min_c_div = (*min + const_sum).div_euclid(n_int);
let max_c_div = (*max + const_sum).div_euclid(n_int);
if min_div == min_c_div && max_div == max_c_div {
return x_without_const.try_div(&Arc::clone(n)).ok();
}
}
None
},
Idiv(Add[a, Mul[Idiv(x, n @const(n_val)), n]], n) => {
let (vmin, vmax) = VminVmaxProperty::get(a);
if let (ConstValue::Int(min), ConstValue::Int(max), ConstValue::Int(n_int)) = (vmin, vmax, n_val)
&& *min >= 0 && *max < n_int && n_int > 0 {
return Some(x.idiv(n));
}
None
},
Idiv(Add[x, _c @const(c_val)], d @const(d_val)) => {
let c_int = c_val.try_int()?;
let d_int = d_val.try_int()?;
if d_int <= 0 || c_int <= 0 { return None; }
let (vmin, vmax) = VminVmaxProperty::get(x);
if let (ConstValue::Int(min), ConstValue::Int(max)) = (vmin, vmax)
&& *min >= 0
{
let max_rem = if max - min >= d_int - 1 || *min % d_int > *max % d_int {
d_int - 1
} else {
*max % d_int
};
if max_rem + c_int < d_int {
return Some(x.idiv(d));
}
}
None
},
Idiv(Add[x, _c @const(c_val)], d @const(d_val)) => {
let c_int = c_val.try_int()?;
let d_int = d_val.try_int()?;
if d_int <= 0 { return None; }
let c_mod_d = c_int % d_int;
let c_div_d = c_int / d_int;
if c_mod_d == c_int { return None; }
let (vmin, _) = VminVmaxProperty::get(x);
if let ConstValue::Int(min) = vmin {
if min + c_int < 0 || min + c_mod_d < 0 { return None; }
} else { return None; }
let remainder_const = UOp::const_(d.dtype(), ConstValue::Int(c_mod_d));
let inner = x.add(&remainder_const);
let div_result = inner.idiv(d);
let quotient_const = UOp::const_(d.dtype(), ConstValue::Int(c_div_d));
Some(div_result.add("ient_const))
},
Idiv(Add[x, _c @const(c_val)], d @const(d_val)) => {
let c_int = c_val.try_int()?;
let d_int = d_val.try_int()?;
if d_int <= 0 { return None; }
let (x_vmin, x_vmax) = VminVmaxProperty::get(x);
let n_expr = x.add(&UOp::const_(x.dtype(), c_val));
let n_vmin = n_expr.vmin();
if let (ConstValue::Int(_), ConstValue::Int(xmax)) = (x_vmin, x_vmax)
&& let ConstValue::Int(nmin) = n_vmin
&& *xmax <= 0 && *nmin >= 0
{
let c_mod_d = c_int.rem_euclid(d_int);
let c_div_d = c_int.div_euclid(d_int);
let c_mod_const = UOp::const_(d.dtype(), ConstValue::Int(c_mod_d));
let d_minus_1 = UOp::const_(d.dtype(), ConstValue::Int(d_int - 1));
let inner = c_mod_const.add(x).sub(&d_minus_1).neg();
let div_result = inner.idiv(d).neg();
let quotient_const = UOp::const_(d.dtype(), ConstValue::Int(c_div_d));
return Some(div_result.add("ient_const));
}
None
},
for op in binary [Idiv, Mod] {
d @ op(x, y) if d.dtype() == DType::Index => crate::symbolic::divmod::fold_divmod_general(op, x, y),
},
}
}
/// Returns the matcher holding the basic division identities.
pub fn division_dsl_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
// Float 0 / 0 -> NaN.
Fdiv(zero1 @ @zero, @zero) if zero1.dtype().is_float()
~> UOp::const_(zero1.dtype(), ConstValue::Float(f64::NAN)),
// Float (_ * 0) / 0 -> NaN.
Fdiv(Mul[_, zero1 @ @zero], @zero) if zero1.dtype().is_float()
~> UOp::const_(zero1.dtype(), ConstValue::Float(f64::NAN)),
// x / x -> 1 (no rewrite if x's dtype has no scalar type).
Fdiv(x, x) => x.dtype().scalar().map(|dt| UOp::const_(x.dtype(), ConstValue::one(dt))),
// (x * y) / y -> x, for float and integer division alike.
// NOTE(review): no guard on `y` being non-zero is visible here — confirm intended.
Fdiv(Mul(x, y), y) ~> x.clone(),
Idiv(Mul(x, y), y) ~> x.clone(),
}
}
/// Reports whether every value of `to` can be stored in `from` without loss,
/// i.e. whether a round trip `to -> from -> to` is the identity.
///
/// Only scalar and vector dtypes are considered (by their element type);
/// anything else, and any element type missing from the table below, yields
/// `false`. Rules after the trivial "same element type" case:
///
/// * integer/bool and float never mix;
/// * float: `from` must be strictly wider. `Float16` and `BFloat16` are both
///   16 bits wide but neither can represent all values of the other, so
///   equal width is not sufficient;
/// * same signedness: `from` must be at least as wide;
/// * unsigned `to`, signed `from`: `from` must be strictly wider;
/// * signed `to`, unsigned `from`: never safe.
fn can_safe_cast(to: &DType, from: &DType) -> bool {
    use morok_dtype::ScalarDType;

    // Element type of a scalar or vector dtype.
    let element = |dt: &DType| match dt {
        DType::Scalar(s) => Some(*s),
        DType::Vector { scalar, .. } => Some(*scalar),
        _ => None,
    };
    // (bit width, is signed, is float) for the element types handled here.
    let traits = |s: ScalarDType| -> Option<(u32, bool, bool)> {
        match s {
            ScalarDType::Bool => Some((1, false, false)),
            ScalarDType::Int8 => Some((8, true, false)),
            ScalarDType::Int16 => Some((16, true, false)),
            ScalarDType::Int32 => Some((32, true, false)),
            ScalarDType::Int64 => Some((64, true, false)),
            ScalarDType::UInt8 => Some((8, false, false)),
            ScalarDType::UInt16 => Some((16, false, false)),
            ScalarDType::UInt32 => Some((32, false, false)),
            ScalarDType::UInt64 => Some((64, false, false)),
            ScalarDType::Float16 | ScalarDType::BFloat16 => Some((16, true, true)),
            ScalarDType::Float32 => Some((32, true, true)),
            ScalarDType::Float64 => Some((64, true, true)),
            _ => None,
        }
    };

    let (Some(to_scalar), Some(from_scalar)) = (element(to), element(from)) else {
        return false;
    };
    if to_scalar == from_scalar {
        return true;
    }
    let (Some((to_bits, to_signed, to_float)), Some((from_bits, from_signed, from_float))) =
        (traits(to_scalar), traits(from_scalar))
    else {
        return false;
    };

    if to_float != from_float {
        return false;
    }
    if to_float {
        // Identical float types returned above, so equal width here means the
        // Float16 / BFloat16 pair, which is lossy in both directions.
        return from_bits > to_bits;
    }
    if to_signed == from_signed {
        return from_bits >= to_bits;
    }
    // Mixed signedness: only an unsigned `to` inside a strictly wider signed
    // `from` is safe.
    !to_signed && from_signed && from_bits > to_bits
}
/// Returns the matcher holding the basic cast simplifications: constant
/// casts, no-op casts and removable intermediate casts.
pub fn cast_dsl_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
// Cast of a constant: fold through `ConstValue::cast`; no rewrite when it
// returns `None`.
Cast { src: _c @const(c_val), dtype } => c_val.cast(dtype).map(|v| UOp::const_(dtype.clone(), v)),
// Cast to the source's own dtype is a no-op.
Cast { src: x, dtype } if x.dtype() == *dtype ~> x.clone(),
// x -> intermediate -> back to x's dtype: both casts vanish when
// `can_safe_cast` accepts the (outer, intermediate) pair.
Cast { src: Cast { src: x, dtype: intermediate }, dtype: outer }
if x.dtype() == *outer && can_safe_cast(outer, intermediate)
~> x.clone(),
// x -> intermediate -> outer: the intermediate cast is skipped when
// `can_safe_cast` accepts the (x's dtype, intermediate) pair.
Cast { src: Cast { src: x, dtype: intermediate }, dtype: outer }
if can_safe_cast(&x.dtype(), intermediate)
~> |x, outer| x.cast(outer.clone()),
}
}
/// Returns the matcher that removes an intermediate integer cast when the
/// value bounds of the source fit inside the intermediate type.
fn range_based_cast_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
Cast { src: Cast { src: x, dtype: intermediate }, dtype: outer }
// Only integer / Index sources and intermediates are considered.
if (x.dtype().is_int() || x.dtype() == DType::Index)
&& (intermediate.is_int() || *intermediate == DType::Index)
=> {
let (vmin, vmax) = VminVmaxProperty::get(x);
// Representable range of the intermediate type, expressed in i64.
// Types not listed (including UInt64) decline the rewrite.
let (imin, imax) = match intermediate.scalar() {
Some(ScalarDType::Int8) => (i8::MIN as i64, i8::MAX as i64),
Some(ScalarDType::Int16) => (i16::MIN as i64, i16::MAX as i64),
Some(ScalarDType::Int32) => (i32::MIN as i64, i32::MAX as i64),
Some(ScalarDType::Int64) => (i64::MIN, i64::MAX),
Some(ScalarDType::UInt8) => (0, u8::MAX as i64),
Some(ScalarDType::UInt16) => (0, u16::MAX as i64),
Some(ScalarDType::UInt32) => (0, u32::MAX as i64),
_ => return None,
};
// Both bounds must be integer constants inside [imin, imax]; then the
// source is cast straight to the outer dtype.
if let (ConstValue::Int(vmin_v), ConstValue::Int(vmax_v)) = (vmin, vmax)
&& imin <= *vmin_v && *vmax_v <= imax {
return Some(x.cast(outer.clone()));
}
None
},
}
}
/// Returns the matcher that merges repeated terms of a sum into a single
/// multiplication and merges nested float divisions.
pub fn term_combining_dsl_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
// x + x -> 2 * x
Add(x, x) ~> 2.into_uop(x.dtype()).mul(x),
// x*c1 + x*c2 -> x * (c1 + c2)
Add(Mul[x, c1 @const(c1_val)], Mul[x, _c2 @const(c2_val)])
~> x.mul(&eval_add_typed(c1_val, c2_val, c1.dtype().base())
.expect("failed to add constants")
.into_uop(c1.dtype())),
// x + x*c -> x * (c + 1)
Add[x, Mul[x, c @const(c_val)]] ~> {
let one = ConstValue::one(c.dtype().base());
let new_c = eval_add_typed(c_val, one, c.dtype().base()).expect("failed to add constants");
x.mul(&UOp::const_(c.dtype(), new_c))
},
// (y + x*c0) + x*c1 -> y + x * (c0 + c1)
Add[Add[y, Mul[x, c0 @const(c0_val)]], Mul[x, _c1 @const(c1_val)]] ~> {
let new_c = eval_add_typed(c0_val, c1_val, c0.dtype().base()).expect("failed to add constants");
let xc = x.mul(&UOp::const_(c0.dtype(), new_c));
y.add(&xc)
},
// (y + x) + x*c -> y + x * (c + 1)
Add[Add[y, x], Mul[x, c @const(c_val)]] ~> {
let one = ConstValue::one(c.dtype().base());
let new_c = eval_add_typed(c_val, one, c.dtype().base()).expect("failed to add constants");
let xc = x.mul(&UOp::const_(c.dtype(), new_c));
y.add(&xc)
},
// (y + x*c) + x -> y + x * (c + 1)
Add[Add[y, Mul[x, c @const(c_val)]], x] ~> {
let one = ConstValue::one(c.dtype().base());
let new_c = eval_add_typed(c_val, one, c.dtype().base()).expect("failed to add constants");
let xc = x.mul(&UOp::const_(c.dtype(), new_c));
y.add(&xc)
},
// (y + x) + x -> y + 2 * x
Add[Add[y, x], x] ~> {
let x2 = 2.into_uop(x.dtype()).mul(x);
y.add(&x2)
},
// (x / x2) / x3 -> x / (x2 * x3), skipped when x2 and x3 are the same node.
Fdiv(Fdiv(x, x2), x3)
if !Arc::ptr_eq(x2, x3)
=> {
let denom = x2.try_mul(x3).ok()?;
x.try_div(&denom).ok()
},
}
}
/// Returns the matcher holding division rules that rely on constant
/// divisors: nested divisions, exact divisibility (`divides`), congruence
/// folding and distribution of a constant factor over an Index-typed sum.
pub fn advanced_division_dsl_patterns() -> &'static TypedPatternMatcher {
    crate::cached_patterns! {
        // (a / b) / c -> a / (b * c) for non-zero constants b and c.
        Idiv(Idiv(a, b @const(b_val)), _c @const(c_val)) if !b_val.is_zero() && !c_val.is_zero() ~> {
            let mul = eval_mul_typed(b_val, c_val, b.dtype().base()).expect("failed to multiply constants");
            a.idiv(&UOp::const_(b.dtype(), mul))
        },
        // expr / divisor when `divides` can produce an exact quotient.
        Idiv(expr, divisor @ @const) => expr.divides(divisor),
        // Congruence-based folding; the final flag selects Mod (true) or Idiv (false).
        Mod(x, c @const(c_val)) => crate::symbolic::divmod::fold_divmod_congruence(x, c, c_val, true),
        Idiv(x, c @const(c_val)) => crate::symbolic::divmod::fold_divmod_congruence(x, c, c_val, false),
        // (a + b) / c -> a/c + b/c and (a - b) / c -> a/c - b/c, only when both
        // terms are accepted by `divides`.
        Idiv(Add(a, b), c @ @const) => Some(a.divides(c)?.add(&b.divides(c)?)),
        Idiv(Sub(a, b), c @ @const) => Some(a.divides(c)?.sub(&b.divides(c)?)),
        // y * (x + c) -> y*x + y*c for Index-typed x and constant y, c.
        Mul[y @const(_yv), Add[x, c @const(_cv)]] if x.dtype() == DType::Index ~> y.mul(x).add(&y.mul(c)),
        // (a / c1 + c2) / c3 -> (a + c1*c2) / (c1*c3) for c1, c3 > 0, when a and
        // c2 have the same sign (both >= 0 or both <= 0).
        Idiv(Add[Idiv(a, c1 @const(c1_val)), _c2 @const(c2_val)], _c3 @const(c3_val)) => {
            // Non-integer constants or bounds decline the rewrite instead of
            // panicking, like the other fallible arms in this file.
            let c1_int = c1_val.try_int()?;
            let c2_int = c2_val.try_int()?;
            let c3_int = c3_val.try_int()?;
            if c1_int <= 0 || c3_int <= 0 { return None; }
            let a_vmin = a.vmin().try_int()?;
            let a_vmax = a.vmax().try_int()?;
            if !((a_vmin >= 0 && c2_int >= 0) || (a_vmax <= 0 && c2_int <= 0)) { return None; }
            let c1_times_c2 = eval_mul_typed(c1_val, c2_val, c1.dtype().base()).expect("failed to evaluate cprod");
            let c1_times_c3 = eval_mul_typed(c1_val, c3_val, c1.dtype().base()).expect("failed to evaluate cprod");
            Some(a.add(&UOp::const_(c1.dtype(), c1_times_c2))
                .idiv(&UOp::const_(c1.dtype(), c1_times_c3)))
        },
    }
}
/// Returns the matcher that folds pairs of constants separated by one level
/// of nesting and moves constants toward the outside of `Add` / `Mul` chains.
pub fn alu_folding_dsl_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
// (x + c1) + c2 -> x + (c1 + c2)
Add[Add[x, c1 @const(c1_val)], _c2 @const(c2_val)] ~> {
let csum = eval_add_typed(c1_val, c2_val, c1.dtype().base()).expect("failed to add constants");
x.add(&UOp::const_(c1.dtype(), csum))
},
// (x + c) + y -> (x + y) + c for non-constant y: the constant moves outward.
Add[Add[x, c @const(_c_val)], y] if !matches!(y.op(), Op::Const(_)) ~> x.add(y).add(c),
// (x * c1) * c2 -> x * (c1 * c2)
Mul[Mul[x, c1 @const(c1_val)], _c2 @const(c2_val)] ~> {
let cmul = eval_mul_typed(c1_val, c2_val, c1.dtype().base()).expect("failed to multiply constants");
x.mul(&UOp::const_(c1.dtype(), cmul))
},
// (x * c) * y -> (x * y) * c for non-constant y.
Mul[Mul[x, c @const(_c_val)], y] if !matches!(y.op(), Op::Const(_)) ~> x.mul(y).mul(c),
// Same two-constant folding for And / Or / Xor / Max; these arms decline
// when `eval_binary_op` returns `None`.
And[And[x, c1 @const(c1_val)], _c2 @const(c2_val)]
=> eval_binary_op(BinaryOp::And, c1_val, c2_val).map(|r| x.and_(&UOp::const_(c1.dtype(), r))),
Or[Or[x, c1 @const(c1_val)], _c2 @const(c2_val)]
=> eval_binary_op(BinaryOp::Or, c1_val, c2_val).map(|r| x.or_(&UOp::const_(c1.dtype(), r))),
Xor[Xor[x, c1 @const(c1_val)], _c2 @const(c2_val)]
=> eval_binary_op(BinaryOp::Xor, c1_val, c2_val).map(|r| x.xor(&UOp::const_(c1.dtype(), r))),
Max(Max(x, c1 @const(c1_val)), _c2 @const(c2_val))
=> eval_binary_op(BinaryOp::Max, c1_val, c2_val).map(|r| x.try_max(&UOp::const_(c1.dtype(), r)).expect("max failed")),
// (x - c1) + c2 -> x + (c2 - c1); a negative integer difference is emitted
// as a subtraction of its absolute value instead.
Add[Sub(x, c1 @const(c1_val)), _c2 @const(c2_val)] ~> {
let diff_val = eval_sub_typed(c2_val, c1_val, c1.dtype().base()).expect("failed to subtract constants");
if let ConstValue::Int(v) = diff_val && v < 0 {
x.sub(&(-v).into_uop(c1.dtype()))
} else {
x.add(&UOp::const_(c1.dtype(), diff_val))
}
},
// (x + c1) - c2 -> x + (c1 - c2), with the same treatment of negative results.
Sub(Add(x, c1 @const(c1_val)), _c2 @const(c2_val)) ~> {
let diff_val = eval_sub_typed(c1_val, c2_val, c1.dtype().base()).expect("failed to subtract constants");
if let Some(v) = diff_val.try_int() && v < 0 {
x.sub(&(-v).into_uop(c1.dtype()))
} else {
x.add(&UOp::const_(c1.dtype(), diff_val))
}
},
// (x - c1) - c2 -> x - (c1 + c2)
Sub(Sub(x, c1 @const(c1_val)), _c2 @const(c2_val)) ~> {
let csum = eval_add_typed(c1_val, c2_val, c1.dtype().base()).expect("failed to add constants");
x.sub(&UOp::const_(c1.dtype(), csum))
},
// a - (b - x) -> x + (a - b)
Sub(a, Sub(b, x)) ~> x.add(&a.sub(b)),
// -1 * (x + c) -> (-x) + (-c); declines when the dtype has no -1 value.
Mul[_neg @const(nv), Add[x, c @const(cv)]] if nv.is_neg_one() => {
let neg_one = ConstValue::neg_one(c.dtype().base())?;
let neg_cv = eval_mul_typed(cv, neg_one, c.dtype().base()).expect("failed to negate constant");
Some(UOp::neg(x).add(&UOp::const_(c.dtype(), neg_cv)))
},
}
}
/// Returns the matcher that eliminates empty or single-valued ranges and the
/// `End` / `Reduce` nodes that depend only on empty ranges.
pub fn dead_loop_patterns() -> &'static TypedPatternMatcher {
use crate::symbolic::dce::reduce_identity;
// Rebuilds an `End` node without its empty ranges; when no range survives,
// the wrapped computation is returned on its own.
fn filter_dead_ranges(end_op: &Arc<UOp>) -> Arc<UOp> {
let Op::End { computation, ranges } = end_op.op() else { unreachable!("filter_dead_ranges called on non-End") };
let live_ranges: SmallVec<[Arc<UOp>; 4]> = ranges.iter().filter(|r| !is_empty_range(r)).cloned().collect();
if live_ranges.is_empty() {
Arc::clone(computation)
} else {
computation.end(live_ranges)
}
}
// A node is "trivial" when its lower and upper bounds coincide.
fn is_trivial_range(uop: &Arc<UOp>) -> bool {
let (vmin, vmax) = VminVmaxProperty::get(uop);
vmin == vmax
}
// The single value of a trivial node, as a constant of the node's dtype.
fn trivial_range_value(uop: &Arc<UOp>) -> Arc<UOp> {
let (vmin, _) = VminVmaxProperty::get(uop);
UOp::const_(uop.dtype(), *vmin)
}
crate::cached_patterns! {
// Empty range -> index constant 0.
r @ Range(_) if is_empty_range(r) ~> UOp::index_const(0),
// Range with a constant end and identical bounds -> that single value.
r @ Range { end: Const(_) } if is_trivial_range(r) ~> trivial_range_value(r),
// End with at least one empty range -> drop the empty ranges.
end_op @ End { ranges, .. } if ranges.iter().any(is_empty_range) ~> filter_dead_ranges(end_op),
// Reduce whose ranges are all empty -> the value given by `reduce_identity`.
rop @ Reduce { ranges, reduce_op: op, .. } if !ranges.is_empty() && ranges.iter().all(is_empty_range)
~> reduce_identity(*op, rop.dtype()),
}
}
/// Returns the matcher that replaces a node by a constant when
/// `SoundVminVmaxProperty` reports identical lower and upper bounds for it.
pub fn vmin_vmax_collapse_patterns() -> &'static TypedPatternMatcher {
use morok_ir::uop::properties::SoundVminVmaxProperty;
// Node kinds this rule set is allowed to collapse.
fn is_collapsible(uop: &Arc<UOp>) -> bool {
matches!(uop.op(), Op::Binary(..) | Op::Unary(..) | Op::Ternary(..) | Op::DefineVar { .. } | Op::Special { .. })
}
// Returns a constant shaped like `uop` when the property is available and its
// two bounds are equal; otherwise `None`.
fn try_collapse(uop: &Arc<UOp>) -> Option<Arc<UOp>> {
let (vmin, vmax) = SoundVminVmaxProperty::get(uop).as_ref()?;
if vmin == vmax { Some(uop.const_like(*vmin)) } else { None }
}
crate::cached_patterns! {
// `op` is bound by the loop header but unused; `let _ = op` silences that.
for op in binary [Add, Mul, Sub, Mod, Max, Pow, Idiv, Fdiv, And, Or, Xor, Shl, Shr, Lt, Le, Eq, Ne, Gt, Ge] {
r @ op(_, _) if is_collapsible(r) => { let _ = op; try_collapse(r) },
},
for op in unary [Sqrt, Exp2, Log2, Sin, Reciprocal, Trunc, Not, Floor, Ceil, Round] {
r @ op(_) if is_collapsible(r) => { let _ = op; try_collapse(r) },
},
for op in ternary [Where, MulAcc] {
r @ op(_, _, _) if is_collapsible(r) => { let _ = op; try_collapse(r) },
},
// Variables and special nodes whose bounds coincide collapse as well.
r @ DefineVar { name: _, min_val: _, max_val: _ } => try_collapse(r),
r @ Special { end: _, name: _ } => try_collapse(r),
}
}
/// Returns the matcher holding the basic `Where` eliminations.
pub fn dce_dsl_simple_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
// Condition whose bounds are both `true` (or both `false`): pick the
// corresponding branch.
Where(cond, true_val, false_val) => {
match VminVmaxProperty::get(cond) {
(ConstValue::Bool(true), ConstValue::Bool(true)) => Some(Arc::clone(true_val)),
(ConstValue::Bool(false), ConstValue::Bool(false)) => Some(Arc::clone(false_val)),
_ => None,
}
},
// Both branches identical -> the branch.
Where(_, t, t) ~> |t| Arc::clone(t),
// Where(x, true, false) on a Bool x -> x
Where(x, _t @const(t_val), _f @const(f_val))
if x.dtype() == DType::Bool && t_val == ConstValue::Bool(true) && f_val == ConstValue::Bool(false)
~> Arc::clone(x),
// Where(x, false, true) on a Bool x -> !x
Where(x, _t @const(t_val), _f @const(f_val))
if x.dtype() == DType::Bool && t_val == ConstValue::Bool(false) && f_val == ConstValue::Bool(true)
~> x.not(),
// Where(a, Where(b, c, d), d) -> Where(a & b, c, d); declines if the new
// Where cannot be built.
Where(a, Where(b, c, d), d) => {
let combined_cond = a.and_(b);
UOp::try_where(combined_cond, Arc::clone(c), Arc::clone(d)).ok()
},
}
}
/// Returns the matcher that removes a negated `Where` condition by swapping
/// the branches.
pub fn dce_dsl_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
// Where(!cond, t, f) -> Where(cond, f, t). Skipped when `f` is (or contains
// as a vector lane) an Invalid marker, since the swap would move it into the
// true branch.
Where(Not(cond), t, f)
if !has_invalid(f)
=> UOp::try_where(Arc::clone(cond), Arc::clone(f), Arc::clone(t)).ok(),
}
}
/// Returns the matcher that normalises the dependency list of `After` nodes.
///
/// Dependencies that are `Range`, `Store`, `End`, `Kernel`, `Barrier` or
/// `Unroll` are kept; any other dependency is replaced by its own sources. An
/// `After` left without dependencies collapses to its passthrough value.
pub fn after_simplification_patterns() -> &'static TypedPatternMatcher {
    crate::cached_patterns! {
        After { passthrough, deps } if !deps.is_empty() => {
            let mut kept = smallvec::SmallVec::<[Arc<UOp>; 4]>::new();
            let mut rewritten = false;
            for dep in deps {
                match dep.op() {
                    Op::Range { .. }
                    | Op::Store { .. }
                    | Op::End { .. }
                    | Op::Kernel { .. }
                    | Op::Barrier { .. }
                    | Op::Unroll { .. } => kept.push(Arc::clone(dep)),
                    other => {
                        // Look through this dependency: depend on its sources instead.
                        kept.extend(other.sources());
                        rewritten = true;
                    }
                }
            }
            // Every dependency was already of a kept kind: nothing to rewrite.
            if !rewritten {
                return None;
            }
            Some(if kept.is_empty() {
                Arc::clone(passthrough)
            } else {
                passthrough.after(kept)
            })
        },
        // No dependencies at all: the node is just its passthrough value.
        After { passthrough, deps } if deps.is_empty() ~> Arc::clone(passthrough),
    }
}
/// Returns the matcher that moves a `Where` guarding an ungated `Index`
/// (with a zero constant in the other branch) into the index itself, via
/// `where_on_load_index_transform`.
pub fn pm_move_where_on_load() -> &'static TypedPatternMatcher {
    crate::cached_patterns! {
        // Where(cond, index, 0): the condition is passed on as is.
        Where(cond, idx @ Index { buffer, indices, gate: None }, f @ const(false_val)) if false_val.is_zero() => {
            where_on_load_index_transform(cond, buffer, indices, f, idx.dtype())
        },
        // Where(cond, 0, index): the index is live when `cond` is false, so the
        // negated condition is passed on.
        Where(cond, f @ const(false_val), idx @ Index { buffer, indices, gate: None }) if false_val.is_zero() => {
            let not_cond = cond.not();
            where_on_load_index_transform(&not_cond, buffer, indices, f, idx.dtype())
        },
    }
}
/// Reports whether `uop` is an `Op::Invalid` node, or a `Vectorize` node with
/// at least one `Op::Invalid` element. Only direct elements are inspected.
fn has_invalid(uop: &Arc<UOp>) -> bool {
    if let Op::Vectorize { elements } = uop.op() {
        return elements.iter().any(|lane| matches!(lane.op(), Op::Invalid));
    }
    matches!(uop.op(), Op::Invalid)
}
fn where_on_load_index_transform(
cond: &Arc<UOp>,
idx_buf: &Arc<UOp>,
indices: &SmallVec<[Arc<UOp>; 4]>,
false_val: &Arc<UOp>,
index_dtype: DType,
) -> Option<Arc<UOp>> {
let c1_clauses = split_and_clauses(cond);
let existing_valid = indices.first()?.get_valid();
let c2_clauses: Vec<Arc<UOp>> = if matches!(existing_valid.op(), Op::Const(cv) if cv.0 == ConstValue::Bool(true)) {
vec![]
} else {
split_and_clauses(&existing_valid)
};
let duplicate_ids: std::collections::HashSet<u64> =
c1_clauses.iter().filter(|c| c2_clauses.iter().any(|c2| c.id == c2.id)).map(|c| c.id).collect();
let mut index_ranges = std::collections::HashSet::new();
let mut idx_indices = std::collections::HashSet::new();
for idx in indices {
let mut visited = std::collections::HashSet::new();
let mut stack = vec![idx.clone()];
while let Some(node) = stack.pop() {
if !visited.insert(Arc::as_ptr(&node)) {
continue;
}
match node.op() {
Op::Range { .. } => {
index_ranges.insert(node.id);
}
Op::Index { .. } => {
idx_indices.insert(node.id);
}
_ => {}
}
node.op().map_child(|child| {
if !visited.contains(&Arc::as_ptr(child)) {
stack.push(child.clone());
}
});
}
}
let (moved_clauses, remaining_clauses): (Vec<_>, Vec<_>) = c1_clauses.iter().cloned().partition(|clause| {
if duplicate_ids.contains(&clause.id) {
return true; }
let mut ranges_in_scope = true;
let mut has_index_deps = false;
let mut visited = std::collections::HashSet::new();
let mut stack = vec![clause.clone()];
while let Some(node) = stack.pop() {
if !visited.insert(Arc::as_ptr(&node)) {
continue;
}
match node.op() {
Op::Range { .. } if !index_ranges.contains(&node.id) => {
ranges_in_scope = false;
break; }
Op::Index { .. } if !idx_indices.contains(&node.id) => {
has_index_deps = true;
break; }
_ => {}
}
node.op().map_child(|child| {
if !visited.contains(&Arc::as_ptr(child)) {
stack.push(child.clone());
}
});
}
ranges_in_scope && !has_index_deps
});
let actually_moved: Vec<_> = moved_clauses.into_iter().filter(|c| !duplicate_ids.contains(&c.id)).collect();
if actually_moved.is_empty() && duplicate_ids.is_empty() {
return None; }
let mut validity_clauses: Vec<Arc<UOp>> = actually_moved;
validity_clauses.extend(c2_clauses);
let clean_idx = indices.first()?.get_idx();
let new_idx = if validity_clauses.is_empty() {
clean_idx
} else {
let combined_valid = validity_clauses.into_iter().reduce(|a, b| a.and_(&b)).unwrap();
clean_idx.valid(combined_valid)
};
let mut new_indices = indices.clone();
new_indices[0] = new_idx;
let new_index = UOp::index()
.buffer(idx_buf.clone())
.indices(new_indices)
.call()
.expect("where_on_load_index_transform: INDEX construction failed")
.with_dtype(index_dtype);
if remaining_clauses.is_empty() {
Some(new_index)
} else {
let remaining_cond = remaining_clauses.into_iter().reduce(|a, b| a.and_(&b)).unwrap();
UOp::try_where(remaining_cond, new_index, false_val.clone()).ok()
}
}
/// Flattens a tree of `And` nodes into the list of its non-`And` leaves.
///
/// Leaves are returned in left-to-right order. A condition that is not an
/// `And` at the root yields a single-element vector containing `cond` itself.
fn split_and_clauses(cond: &Arc<UOp>) -> Vec<Arc<UOp>> {
    let mut clauses = Vec::new();
    // Explicit work stack; the right operand is pushed first so the left one
    // is popped (and therefore emitted) first.
    let mut pending = vec![cond];
    while let Some(node) = pending.pop() {
        if let Op::Binary(BinaryOp::And, left, right) = node.op() {
            pending.push(right);
            pending.push(left);
        } else {
            clauses.push(Arc::clone(node));
        }
    }
    clauses
}
/// Pattern set that pushes a cast through a `Where`.
///
/// `cast(where(s, a, b), dtype)` becomes `where(s, cast(a, dtype), cast(b, dtype))`.
///
/// # Panics
///
/// The rewrite panics if `UOp::try_where` rejects the rebuilt node.
pub fn cast_where_dsl_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
// Cast each branch on its own, then rebuild the select on the original condition.
Cast { src: Where(s, a, b), dtype } ~> {
let cast_a = a.cast(dtype.clone());
let cast_b = b.cast(dtype.clone());
UOp::try_where(s.clone(), cast_a, cast_b).expect("failed to create WHERE")
},
}
}
/// Pattern set that folds and simplifies comparison operations.
///
/// Covers self-comparison, constant folding, analysis through
/// `ComparisonAnalyzer`, and several algebraic rewrites of `Lt`.
///
/// # Panics
///
/// Several rewrites `expect` that building the replacement comparison (or the
/// constant subtraction) succeeds.
pub fn comparison_dsl_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
for op in binary [Lt, Le, Eq, Ne, Gt, Ge] {
op(x, y) => {
// Same node on both sides (pointer equality), non-float dtype:
// the strict/inequality ops fold to false, the reflexive ops to true.
// Floats are excluded - presumably because NaN is not equal to itself.
if Arc::ptr_eq(x, y) && !x.dtype().is_float() {
let result = match op {
BinaryOp::Lt | BinaryOp::Gt | BinaryOp::Ne => ConstValue::Bool(false),
BinaryOp::Le | BinaryOp::Ge | BinaryOp::Eq => ConstValue::Bool(true),
_ => return None,
};
return Some(UOp::const_(DType::Bool, result));
}
// Both operands have a known constant value: evaluate the comparison now.
if let (Some(a_val), Some(b_val)) = (get_const_value(x), get_const_value(y))
&& let Some(result) = eval_binary_op(op, a_val, b_val)
{
return Some(UOp::const_(DType::Bool, result));
}
// Otherwise let the comparison analyzer try to decide the result.
if let Some(result) = ComparisonAnalyzer::analyze(op, x, y) {
return Some(result.into_uop(DType::Bool));
}
None
},
},
// (c0 + x) < c1  ->  x < (c1 - c0), with the difference evaluated in c0's base dtype.
Lt(Add[c0 @const(c0_val), x], _c1 @const(c1_val)) ~> {
let diff = eval_sub_typed(c1_val, c0_val, c0.dtype().base()).expect("failed to evaluate sub");
x.try_cmplt(&UOp::const_(c0.dtype(), diff)).expect("failed to create cmplt")
},
// (x * -1) < (y * -1)  ->  y < x.
Lt(Mul[x, _c1 @const(c1v)], Mul[y, _c2 @const(c2v)])
if c1v.is_neg_one() && c2v.is_neg_one()
~> y.try_cmplt(x).expect("failed to create cmplt"),
// (x / d) < c with d > 0  ->  x < bound, where
//   bound = c * d               when c > 0
//   bound = c * d - (d - 1)     otherwise.
Lt(Idiv(x, _d @const(d_val)), _c @const(c_val)) => {
let d_int = d_val.try_int()?;
let c_int = c_val.try_int()?;
if d_int <= 0 { return None; }
let bound = if c_int > 0 {
c_int * d_int
} else {
c_int * d_int - (d_int - 1)
};
Some(x.try_cmplt(&UOp::const_(x.dtype(), ConstValue::Int(bound))).expect("failed to create cmplt"))
},
// (c0 * x) < c1 for Index-typed x.
Lt(Mul[_c0 @const(c0_val), x], _c1 @const(c1_val))
if x.dtype() == DType::Index
=> {
let c0 = c0_val.try_int()?;
let c1 = c1_val.try_int()?;
// Both positive: x < ceil(c1 / c0).
if c0 > 0 && c1 > 0 {
let ceil_div = (c1 + c0 - 1) / c0;
return Some(x.try_cmplt(&UOp::index_const(ceil_div)).expect("failed to create cmplt"));
}
// Negative multiplier other than -1 with c1 <= 0: (-x) < -((-c1) / (-c0)).
if c0 < 0 && c0 != -1 && c1 <= 0 {
let neg_c0 = -c0;
let neg_c1 = -c1;
let floor_div = neg_c1 / neg_c0; return Some(x.neg().try_cmplt(&UOp::index_const(-floor_div)).expect("failed to create cmplt"));
}
None
},
// Index-typed x < c with c > 0: try the gcd-based folding in `lt_folding`.
Lt(x, _c @const(cv)) if x.dtype() == DType::Index => {
let c_int = cv.try_int()?;
if c_int <= 0 { return None; }
lt_folding(x, c_int)
},
}
}
/// Tries to simplify `x < c_int` when `x` is a sum whose non-unit terms share
/// a common divisor with `c_int`.
///
/// `x` is split on `Add` into terms with constant factor 1 ("unit") and all
/// other terms ("scaled"). Let `d` be the gcd of `|c_int|` and every scaled
/// factor. When `d > 1` and the sum of the unit terms is known to lie in
/// `[0, d)`, the unit terms cannot affect the comparison, so the result is
/// `(scaled_sum / d) < (c_int / d)`.
///
/// Returns `None` when `x` has fewer than two terms, when either group is
/// empty, when `d <= 1`, when the bounds are not integers or fall outside
/// `[0, d)`, or when the scaled sum cannot be divided by `d`.
fn lt_folding(x: &Arc<UOp>, c_int: i64) -> Option<Arc<UOp>> {
    let terms = x.split_uop(BinaryOp::Add);
    if terms.len() < 2 {
        return None;
    }

    let (unit_terms, scaled_terms): (Vec<Arc<UOp>>, Vec<Arc<UOp>>) =
        terms.iter().cloned().partition(|term| term.const_factor() == 1);
    if scaled_terms.is_empty() || unit_terms.is_empty() {
        return None;
    }

    // Common divisor of the constant and every scaled term's factor.
    let d = scaled_terms.iter().fold(c_int.unsigned_abs() as i64, |acc, term| gcd(acc, term.const_factor()));
    if d <= 1 {
        return None;
    }

    // The unit part must be provably within [0, d) to be dropped safely.
    let unit_sum = super::divmod::uop_sum(&unit_terms, x);
    let (us_vmin, us_vmax) = VminVmaxProperty::get(&unit_sum);
    let (lo, hi) = (us_vmin.try_int()?, us_vmax.try_int()?);
    if lo < 0 || hi >= d {
        return None;
    }

    let scaled_sum = super::divmod::uop_sum(&scaled_terms, x);
    scaled_sum.divides_int(d)?.try_cmplt(&UOp::index_const(c_int / d)).ok()
}
/// Greatest common divisor of `|a|` and `|b|` (Euclid's algorithm).
///
/// `gcd(0, 0)` is `0`. The computation runs on unsigned magnitudes and the
/// result is cast back to `i64`, so a magnitude of 2^63 (only reachable when
/// an input is `i64::MIN`) wraps to `i64::MIN`.
fn gcd(a: i64, b: i64) -> i64 {
    let mut x = a.unsigned_abs();
    let mut y = b.unsigned_abs();
    while y != 0 {
        (x, y) = (y, x % y);
    }
    x as i64
}
/// Pattern set for basic `Not`/`Xor`/`And`/`Or` identities involving
/// repeated operands or boolean constants.
pub fn boolean_dsl_simple_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
// Double negation: !!x -> x.
Not(Not(x)) ~> x.clone(),
// x ^ x -> zero constant of x's dtype (no rewrite when `scalar()` is `None`).
Xor(x, x) => x.dtype().scalar().map(|dt| UOp::const_(x.dtype(), ConstValue::zero(dt))),
// true | _ -> true.
Or[t @const(t_val), _] if t_val == ConstValue::Bool(true) ~> t.clone(),
// false & _ -> false.
And[f @const(f_val), _] if f_val == ConstValue::Bool(false) ~> f.clone(),
// true & x -> x.
And[_c @const(c_val), x] if c_val == ConstValue::Bool(true) ~> x.clone(),
// false | x -> x.
Or[_c @const(c_val), x] if c_val == ConstValue::Bool(false) ~> x.clone(),
}
}
/// Pattern set for complement and De Morgan rewrites on `And`/`Or`/`Not`.
pub fn boolean_dsl_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
// x | !x -> true (Bool dtype only).
Or[x, Not(x)] if x.dtype() == DType::Bool ~> UOp::const_(DType::Bool, ConstValue::Bool(true)),
// x & !x -> false (Bool dtype only).
And[x, Not(x)] if x.dtype() == DType::Bool ~> UOp::const_(DType::Bool, ConstValue::Bool(false)),
// De Morgan: !x & !y -> !(x | y).
And[Not(x), Not(y)] ~> x.or_(y).not(),
// De Morgan: !x | !y -> !(x & y).
Or[Not(x), Not(y)] ~> x.and_(y).not(),
}
}
/// Pattern set that resolves `Max(x, y)` when value bounds already decide the winner.
pub fn minmax_dsl_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
Max(x, y) => {
let (x_vmin, x_vmax) = VminVmaxProperty::get(x);
let (y_vmin, y_vmax) = VminVmaxProperty::get(y);
// The smallest x is already >= the largest y: the max is x.
if cv_ge(x_vmin, y_vmax) {
return Some(Arc::clone(x));
}
// Symmetric case: the max is y.
if cv_ge(y_vmin, x_vmax) {
return Some(Arc::clone(y));
}
None
},
}
}
/// Pattern set that removes `Where(x < c, t, f)` when value bounds decide the condition.
pub fn where_bound_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
Where(Lt(x, c), t, f) => {
let (x_vmin, x_vmax) = VminVmaxProperty::get(x);
let (c_vmin, c_vmax) = VminVmaxProperty::get(c);
// max(x) < min(c): the condition always holds, keep the true branch.
if cv_lt(x_vmax, c_vmin) { return Some(Arc::clone(t)); }
// min(x) >= max(c): the condition never holds, keep the false branch.
if cv_ge(x_vmin, c_vmax) { return Some(Arc::clone(f)); }
None
},
}
}
/// Returns `true` when `a >= b` and both values are the same numeric variant
/// (`Int`, `UInt` or `Float`).
///
/// Any other combination of variants, including mixed ones, yields `false`.
fn cv_ge(a: &ConstValue, b: &ConstValue) -> bool {
    let pair = (a, b);
    matches!(pair, (ConstValue::Int(lhs), ConstValue::Int(rhs)) if lhs >= rhs)
        || matches!(pair, (ConstValue::UInt(lhs), ConstValue::UInt(rhs)) if lhs >= rhs)
        || matches!(pair, (ConstValue::Float(lhs), ConstValue::Float(rhs)) if lhs >= rhs)
}
/// Returns `true` when `a < b` and both values are the same numeric variant
/// (`Int`, `UInt` or `Float`).
///
/// Any other combination of variants, including mixed ones, yields `false`.
fn cv_lt(a: &ConstValue, b: &ConstValue) -> bool {
    let pair = (a, b);
    matches!(pair, (ConstValue::Int(lhs), ConstValue::Int(rhs)) if lhs < rhs)
        || matches!(pair, (ConstValue::UInt(lhs), ConstValue::UInt(rhs)) if lhs < rhs)
        || matches!(pair, (ConstValue::Float(lhs), ConstValue::Float(rhs)) if lhs < rhs)
}
/// Pattern set that simplifies `Pow` when one operand is a constant.
pub fn power_dsl_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
// Constant exponent: handled by `simplify_pow`.
Pow(x, c @const(cv)) => simplify_pow(x, c, cv),
// Constant base: handled by `simplify_pow_const_base`.
Pow(c @const(cv), x) => simplify_pow_const_base(c, cv, x),
}
}
/// Rewrites `x ** c` for a constant exponent into multiplications, square
/// roots and reciprocals.
///
/// * `x` - the base; only non-vector dtypes (`vcount() <= 1`) are handled.
/// * `c` - the exponent node; its dtype is reused for derived exponents.
/// * `cv` - the exponent's constant value (`Float` or `Int`).
///
/// Rewrites, in order:
/// * `x ** 0`        -> `1`
/// * `x ** 1`        -> `x`
/// * `x ** -n`       -> `(1 / x) ** n`
/// * `x ** (k + ½)`  -> `x ** k * sqrt(x)`
/// * `x ** n` (integer) -> `y * y` (times `x` when `n` is odd), with `y = x ** (n / 2)`
///
/// Returns `None` for vector bases, non-numeric exponents, non-finite
/// exponents, exponents that are neither integers nor half-integers, and
/// integer exponents too large for `i64`.
fn simplify_pow(x: &Arc<UOp>, c: &Arc<UOp>, cv: ConstValue) -> Option<Arc<UOp>> {
    if x.dtype().vcount() > 1 {
        return None;
    }
    let f = match cv {
        ConstValue::Float(f) => f,
        ConstValue::Int(i) => i as f64,
        _ => return None,
    };
    // NaN and ±inf cannot be decomposed: every rewrite below would produce
    // another `Pow` with the same non-finite exponent and never terminate.
    if !f.is_finite() {
        return None;
    }
    if f == 0.0 {
        return x.dtype().scalar().map(|dt| UOp::const_(x.dtype(), ConstValue::one(dt)));
    }
    if f == 1.0 {
        return Some(Arc::clone(x));
    }
    if f < 0.0 {
        let recip = UOp::try_reciprocal(x).ok()?;
        let neg_c = UOp::const_(c.dtype(), ConstValue::Float(-f));
        return recip.try_pow(&neg_c).ok();
    }

    // `f` is positive and finite from here on.
    let frac = f.fract();

    // Exact half-integer test. Comparing `(f - 0.5).floor() + 0.5` with `f`
    // is also true for large `f` where `f - 0.5` rounds back to `f`, which
    // would rewrite `x ** f` into `x ** f * sqrt(x)`.
    if frac == 0.5 {
        let n = UOp::const_(c.dtype(), ConstValue::Float(f - 0.5));
        let pow_n = x.try_pow(&n).ok()?;
        let sqrt_x = x.try_sqrt().ok()?;
        return pow_n.try_mul(&sqrt_x).ok();
    }

    // Integer exponent: square-and-multiply. The bound keeps `f as i64`
    // exact instead of silently saturating at `i64::MAX`.
    if frac == 0.0 && f < i64::MAX as f64 {
        let whole = f as i64;
        let half = UOp::const_(c.dtype(), ConstValue::Float((whole / 2) as f64));
        let y = x.try_pow(&half).ok()?;
        let y2 = y.try_mul(&y).ok()?;
        return if whole % 2 == 1 { y2.try_mul(x).ok() } else { Some(y2) };
    }
    None
}
/// Rewrites `c ** x` for a constant base `c`.
///
/// * base `1`  -> the constant `c` itself
/// * base `> 0` -> `exp2(x * log2(c))`, with `log2(c)` built as a constant of `x`'s dtype
///
/// Returns `None` for vector-typed bases, non-numeric constants, bases that
/// are not positive, or when building the replacement nodes fails.
fn simplify_pow_const_base(c: &Arc<UOp>, cv: ConstValue, x: &Arc<UOp>) -> Option<Arc<UOp>> {
    if c.dtype().vcount() > 1 {
        return None;
    }
    let base = match cv {
        ConstValue::Float(value) => value,
        ConstValue::Int(value) => value as f64,
        _ => return None,
    };
    if base == 1.0 {
        return Some(Arc::clone(c));
    }
    if base > 0.0 {
        let log2_base = UOp::const_(x.dtype(), ConstValue::Float(base.log2()));
        x.try_mul(&log2_base).ok().and_then(|scaled| UOp::try_exp2(&scaled).ok())
    } else {
        None
    }
}
/// Pattern set that hoists a binary op out of two "broadcast" `Vectorize` nodes.
///
/// When both operands are `Vectorize` nodes of equal length (> 1) whose
/// elements are all the same node (pointer equality), the op is built once on
/// the scalar elements and the result is re-vectorized.
fn alu_vectorize_reorder_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
for op in binary [Add, Mul, Sub, Mod, Max, Idiv, Fdiv, Pow, And, Or, Xor, Shl, Shr, Lt, Le, Eq, Ne, Gt, Ge] {
r @ op(Vectorize { elements: x_elems }, Vectorize { elements: y_elems })
if x_elems.len() == y_elems.len()
&& x_elems.len() > 1
// Every adjacent pair is the same node => all lanes are identical.
&& x_elems.windows(2).all(|w| Arc::ptr_eq(&w[0], &w[1]))
&& y_elems.windows(2).all(|w| Arc::ptr_eq(&w[0], &w[1]))
=> {
let scalar_dtype = r.dtype().scalar_dtype();
let count = x_elems.len();
// One scalar op, repeated `count` times in the new Vectorize.
let scalar_alu = UOp::new(Op::Binary(op, x_elems[0].clone(), y_elems[0].clone()), scalar_dtype);
let elems: SmallVec<[Arc<UOp>; 4]> = std::iter::repeat_n(scalar_alu, count).collect();
Some(UOp::vectorize(elems))
},
},
}
}
/// Pattern set that rewrites `x != 0` into a cast of `x` to `Bool`
/// (vectorized to `x`'s lane count).
fn ne_zero_fold_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
Ne(x, _zero @const(zv)) if zv.is_zero() => {
// Keep the lane count of x in the resulting Bool dtype.
let bool_dt = DType::Bool.vec(x.dtype().vcount());
Some(x.cast(bool_dt))
},
}
}
/// Pattern set of algebraic rewrites around `Reciprocal`.
///
/// Every rule returns `None` (no rewrite) when building a replacement node fails.
fn reciprocal_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
// 1 / (x * x) -> (1/x) * (1/x).
Reciprocal(Mul(x, x)) => {
let rx = UOp::try_reciprocal(x).ok()?;
rx.try_mul(&rx).ok()
},
// 1 / ((x * x) * x) -> (1/x) * (1/x) * (1/x).
Reciprocal(Mul(Mul(x, x), x)) => {
let rx = UOp::try_reciprocal(x).ok()?;
let rx2 = rx.try_mul(&rx).ok()?;
rx2.try_mul(&rx).ok()
},
// 1 / (x * c) -> (1/x) * (1/c) for a constant c.
Reciprocal(Mul[x, c @const(_cv)]) => {
let rx = UOp::try_reciprocal(x).ok()?;
let rc = UOp::try_reciprocal(c).ok()?;
rx.try_mul(&rc).ok()
},
// x * (1 / (1 + x)) -> 1 - 1 / (1 + x).
Mul[x, Reciprocal(Add[one @const(ov), x])] if ov.is_one() => {
let d = UOp::try_reciprocal(&one.add(x)).ok()?;
let one_uop = 1.0.into_uop(d.dtype());
one_uop.try_sub(&d).ok()
},
// x * ((1 / (1 + x)) * y) -> y * (1 - 1 / (1 + x)).
Mul[x, Mul[Reciprocal(Add[one @const(ov), x]), y]] if ov.is_one() => {
let d = UOp::try_reciprocal(&one.add(x)).ok()?;
let one_uop = 1.0.into_uop(d.dtype());
let one_minus_d = one_uop.try_sub(&d).ok()?;
y.try_mul(&one_minus_d).ok()
},
// x * (1 / (1 + x) + y) -> (1 - 1 / (1 + x)) + x * y.
Mul[x, Add(Reciprocal(Add[one @const(ov), x]), y)] if ov.is_one() => {
let d = UOp::try_reciprocal(&one.add(x)).ok()?;
let one_uop = 1.0.into_uop(d.dtype());
let one_minus_d = one_uop.try_sub(&d).ok()?;
let xy = x.try_mul(y).ok()?;
one_minus_d.try_add(&xy).ok()
},
}
}
/// Pattern set that pulls multiplicative factors out of `Reduce` nodes.
fn reduce_sym_patterns() -> &'static TypedPatternMatcher {
use morok_ir::types::ReduceOp;
crate::cached_patterns! {
// sum(x * c) -> sum(x) * c for a non-vector constant c.
Reduce { src: Mul[x, c @const(_cv)], ranges, reduce_op }
if *reduce_op == ReduceOp::Add
&& c.dtype().vcount() == 1
=> {
let new_reduce = x.reduce(ranges.clone(), ReduceOp::Add);
// Align the constant's dtype with the new reduce before multiplying.
let c_typed = if c.dtype() == new_reduce.dtype() {
Arc::clone(c)
} else {
c.cast(new_reduce.dtype())
};
new_reduce.try_mul(&c_typed).ok()
},
// Add/Max reduce over a product with matching dtype: factor out the
// terms that `reduce_mul_chain_sym` decides can leave the reduction.
reduce @ Reduce { src, ranges, reduce_op }
if matches!(reduce_op, ReduceOp::Add | ReduceOp::Max)
&& matches!(src.op(), Op::Binary(BinaryOp::Mul, _, _))
&& reduce.dtype() == src.dtype()
=> {
reduce_mul_chain_sym(src, ranges, *reduce_op)
},
}
}
fn reduce_mul_chain_sym(
src: &Arc<UOp>,
ranges: &SmallVec<[Arc<UOp>; 4]>,
reduce_op: morok_ir::types::ReduceOp,
) -> Option<Arc<UOp>> {
use morok_ir::types::ReduceOp;
if !matches!(reduce_op, ReduceOp::Add | ReduceOp::Max) {
return None;
}
let factors = src.split_uop(BinaryOp::Mul);
let range_ids: std::collections::HashSet<u64> = ranges.iter().map(|r| r.id).collect();
let mut inside = Vec::new();
let mut outside = Vec::new();
for factor in &factors {
let factor_ids = factor.backward_slice_ids();
let depends_on_range = range_ids.iter().any(|rid| factor_ids.contains(rid));
if !depends_on_range && (reduce_op != ReduceOp::Max || matches!(factor.vmin(), ConstValue::Int(v) if *v >= 0)) {
outside.push(Arc::clone(factor));
} else {
inside.push(Arc::clone(factor));
}
}
if outside.is_empty() {
return None;
}
let inside_prod = if inside.is_empty() {
src.const_like(ConstValue::one(src.dtype().base()))
} else {
inside.into_iter().reduce(|a, b| a.try_mul(&b).expect("mul failed")).unwrap()
};
let reduced = inside_prod.reduce(ranges.clone(), reduce_op);
let outside_prod = outside.into_iter().reduce(|a, b| a.try_mul(&b).expect("mul failed")).unwrap();
reduced.try_mul(&outside_prod).ok()
}
/// Returns `true` when `u` is a `Sink`, `Vectorize`, `Unroll` or `Noop` node.
fn is_remove_from_sink_like(u: &Arc<UOp>) -> bool {
    let op = u.op();
    matches!(op, Op::Sink { .. } | Op::Vectorize { .. } | Op::Unroll { .. } | Op::Noop)
}
/// Pattern set for distributing multiplication over addition and for
/// flattening `Sink`/`Group` structure.
pub fn sym_phase3_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
// -1 * (x + y) -> (-x) + (-y).
Mul[_neg @const(nv), Add(x, y)] if nv.is_neg_one() ~> x.neg().add(&y.neg()),
// (x + y) * c -> x * c + y * c when x is Index-typed.
Mul[Add(x, y), c @const(_cv)] if x.dtype() == DType::Index ~> x.mul(c).add(&y.mul(c)),
// A Group with a single source is just that source.
Group { sources } if sources.len() == 1 ~> sources[0].clone(),
// Sink: replace every Unroll/Noop/Vectorize/Sink source by that node's own sources.
Sink { sources } if sources.iter().any(is_remove_from_sink_like) => {
let new_srcs: Vec<Arc<UOp>> = sources.iter().flat_map(|s| {
if is_remove_from_sink_like(s) { s.op().sources().to_vec() } else { vec![Arc::clone(s)] }
}).collect();
Some(UOp::sink(new_srcs))
},
// Group: same flattening as for Sink, and nested Groups are flattened too.
Group { sources } if sources.iter().any(|s| is_remove_from_sink_like(s) || matches!(s.op(), Op::Group { .. })) => {
let new_srcs: Vec<Arc<UOp>> = sources.iter().flat_map(|s| {
if is_remove_from_sink_like(s) || matches!(s.op(), Op::Group { .. }) {
s.op().sources().to_vec()
} else { vec![Arc::clone(s)] }
}).collect();
Some(UOp::group(new_srcs))
},
// An End whose computation is a Noop becomes a Void-typed Noop.
End { computation, .. } if matches!(computation.op(), Op::Noop) ~> UOp::new(Op::Noop, DType::Void),
}
}
/// Pattern set that simplifies stores whose value is (or selects) a load
/// from the same index.
pub fn store_load_folding_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
// Storing back a value loaded from the same index -> Void-typed Noop.
// NOTE(review): relies on the repeated `index` binding matching the same node - confirm DSL semantics.
Store { index, value: Load { index, .. } } ~> UOp::new(Op::Noop, DType::Void),
// store(idx, where(gate, alt, load(idx))) on an ungated index: fold the gate
// into the first index as where(gate, original_idx, Invalid) and store `alt`.
Store { index: idx @ Index { buffer: buf, indices, gate: None }, value: Where(gate, alt, Load { index: idx2, .. }), ranges }
if idx.id == idx2.id && !indices.is_empty()
=> {
let original_idx = indices[0].clone();
let invalid = UOp::new(Op::Invalid, original_idx.dtype());
let gated_idx = UOp::try_where(gate.clone(), original_idx, invalid).ok()?;
// Only the first index is replaced; the remaining indices are kept.
let mut new_indices: SmallVec<[Arc<UOp>; 4]> = indices.clone();
new_indices[0] = gated_idx;
let new_index = UOp::index()
.buffer(buf.clone())
.indices(new_indices)
.call()
.ok()?;
// Preserve the store's ranges when it has any.
if ranges.is_empty() {
Some(new_index.store(alt.clone()))
} else {
Some(new_index.store_with_ranges(alt.clone(), ranges.clone()))
}
},
}
}
/// Pattern set that merges two `Where` nodes sharing a condition when they
/// are combined by a binary op and have constant branches.
///
/// # Panics
///
/// The rewrites panic if `UOp::try_where` rejects the combined node.
pub fn where_alu_combining_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
// op(where(c, a, b), where(c, d, e)) -> where(c, op(a, d), op(b, e))
// when both true branches (a, d) are constants.
for op in binary [Add, Mul, Sub, Max, And, Or, Xor] {
r @ op(Where(c, a @const(_a), b), Where(c, d @const(_d), e)) ~> {
let true_branch = UOp::new(Op::Binary(op, Arc::clone(a), Arc::clone(d)), r.dtype());
let false_branch = UOp::new(Op::Binary(op, Arc::clone(b), Arc::clone(e)), r.dtype());
UOp::try_where(Arc::clone(c), true_branch, false_branch).expect("failed to construct WHERE")
},
},
// Same rewrite when both false branches (b, e) are constants.
for op in binary [Add, Mul, Sub, Max, And, Or, Xor] {
r @ op(Where(c, a, b @const(_b)), Where(c, d, e @const(_e))) ~> {
let true_branch = UOp::new(Op::Binary(op, Arc::clone(a), Arc::clone(d)), r.dtype());
let false_branch = UOp::new(Op::Binary(op, Arc::clone(b), Arc::clone(e)), r.dtype());
UOp::try_where(Arc::clone(c), true_branch, false_branch).expect("failed to construct WHERE")
},
},
// (y + where(c, t, f)) + where(c, tt, ff) -> y + where(c, t + tt, f + ff)
// when both true branches are constants.
Add(Add(y, Where(c, t @const(_t), f)), Where(c, tt @const(_tt), ff)) ~> {
let true_sum = t.add(tt);
let false_sum = f.add(ff);
let combined = UOp::try_where(c.clone(), true_sum, false_sum).expect("failed to construct WHERE");
y.add(&combined)
},
// Same rewrite when both false branches are constants.
Add(Add(y, Where(c, t, f @const(_f))), Where(c, tt, ff @const(_ff))) ~> {
let true_sum = t.add(tt);
let false_sum = f.add(ff);
let combined = UOp::try_where(c.clone(), true_sum, false_sum).expect("failed to construct WHERE");
y.add(&combined)
},
}
}
/// Pattern set that simplifies `Gep` (lane extraction) nodes and pushes them
/// towards the leaves of the graph.
///
/// Covers GEP of GEP, GEP of `Vectorize`/constants, identity GEPs, GEP
/// through ALU/cast ops on Index-based dtypes, `Cat` expansion, re-fusing a
/// `Vectorize` of GEPs, and GEP through `Wmma`.
pub fn gep_pushing_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
// GEP on a Void-typed node is dropped.
Gep { vector, .. } if vector.dtype() == DType::Void ~> Arc::clone(vector),
// GEP of GEP: compose the index lists; no rewrite if an outer index is out of range.
Gep { vector: Gep { vector: inner_vec, indices: inner_indices }, indices } => {
let composed: Vec<usize> = indices.iter().map(|&o| inner_indices.get(o).copied()).collect::<Option<Vec<_>>>()?;
Some(inner_vec.gep(composed))
},
// Single-lane GEP of a Vectorize: pick the element directly.
Gep { vector: Vectorize { elements, .. }, indices } if indices.len() == 1 => elements.get(indices[0]).cloned(),
// Multi-lane GEP of a Vectorize: build a Vectorize of the selected elements
// (no rewrite if any index is out of range).
Gep { vector: Vectorize { elements, .. }, indices } if indices.len() > 1 => {
let selected: SmallVec<[Arc<UOp>; 4]> = indices.iter().filter_map(|&i| elements.get(i).cloned()).collect();
(selected.len() == indices.len()).then(|| UOp::vectorize(selected))
},
// GEP of a single-lane constant is the constant itself.
Gep { vector: c @const(_cv), .. } if c.dtype().vcount() == 1 ~> Arc::clone(c),
// GEP of a VConst: select the values, yielding a const (one index) or a VConst.
Gep { vector: v @ VConst { values }, indices } => {
let scalar_dtype = v.dtype().scalar_dtype();
if indices.len() == 1 {
values.get(indices[0]).map(|v| UOp::const_(scalar_dtype.clone(), *v))
} else {
let selected: Vec<_> = indices.iter().filter_map(|&i| values.get(i).cloned()).collect();
(selected.len() == indices.len()).then(|| UOp::vconst(selected, scalar_dtype.clone()))
}
},
// Identity GEP (indices 0..n covering all lanes) on a non-pointer vector.
Gep { vector, indices }
if !matches!(vector.dtype(), DType::Ptr { .. })
&& indices.iter().enumerate().all(|(i, &idx)| i == idx) && indices.len() == vector.dtype().vcount()
~> Arc::clone(vector),
// GEP through unary/binary/ternary/cast ops whose result base dtype is Index:
// apply the GEP to every source and rebuild the op on the narrowed sources.
gep @ Gep { vector, indices }
if !indices.is_empty()
&& gep.dtype().base() == ScalarDType::Index
&& !matches!(gep.dtype(), DType::Ptr { .. })
&& !matches!(vector.dtype(), DType::Ptr { .. })
&& matches!(vector.op(),
Op::Binary(..) | Op::Unary(..) | Op::Ternary(..)
| Op::Cast { .. } | Op::BitCast { .. })
=> {
let sources = vector.op().sources();
let new_sources: Vec<Arc<UOp>> = sources.iter().map(|s| s.gep(indices.clone())).collect();
let gep_count = indices.len();
let scalar_base = vector.dtype().base();
let result_dtype = DType::Scalar(scalar_base).vec(gep_count);
let new_op = match vector.op() {
// NOTE(review): Cast/BitCast target the scalar dtype even when `gep_count > 1`;
// confirm that `cast`/`bitcast` keep the source's lane count in that case.
Op::Cast { .. } => {
let scalar_dt = vector.dtype().scalar_dtype();
return Some(new_sources[0].cast(scalar_dt));
}
Op::BitCast { .. } => {
let scalar_dt = vector.dtype().scalar_dtype();
return Some(new_sources[0].bitcast(scalar_dt));
}
_ => vector.replace().dtype(result_dtype).src(new_sources).call(),
};
Some(new_op)
},
// Cat of non-pointer sources: expand into a Vectorize of single-lane GEPs,
// one per lane of every source, in source order.
Cat { sources } if !matches!(sources.first().map(|s| s.dtype()), Some(DType::Ptr { .. })) => {
let elements: SmallVec<[Arc<UOp>; 4]> = sources.iter()
.flat_map(|s| (0..s.dtype().vcount()).map(move |i| s.gep(vec![i])))
.collect();
if elements.is_empty() { return None; }
Some(UOp::vectorize(elements))
},
// Vectorize of single-lane GEPs that all read the same vector -> one multi-lane GEP.
Vectorize { elements }
if elements.len() > 1 && matches!(elements[0].op(), Op::Gep { .. })
=> {
let Op::Gep { vector: first_src, indices: first_idx } = elements[0].op() else { return None };
if first_idx.len() != 1 { return None; }
let mut combined = Vec::with_capacity(elements.len());
combined.push(first_idx[0]);
for elem in elements.iter().skip(1) {
let Op::Gep { vector, indices } = elem.op() else { return None };
// Every element must be a single-lane GEP of the same source node.
if indices.len() != 1 || vector.id != first_src.id { return None; }
combined.push(indices[0]);
}
Some(first_src.gep(combined))
},
// GEP through WMMA: valid only when the indices select whole output tiles.
Gep { vector: Wmma { a, b, c, metadata }, indices } if !indices.is_empty() => {
// Output tile size: product of the sizes in the `c` upcast axes.
let out_sz: usize = metadata.upcast_axes.c.iter().map(|(_, s)| s).product();
if out_sz == 0 || indices.len() % out_sz != 0 { return None; }
// First index of every group of `out_sz` indices.
let tile_idxs: Vec<usize> = indices.iter().step_by(out_sz).copied().collect();
// Within each group, the i-th index must equal the group's first index + i.
for i in 1..out_sz {
let adjusted: Option<Vec<usize>> = indices
.iter()
.skip(i)
.step_by(out_sz)
.map(|&x| x.checked_sub(i))
.collect();
if adjusted.as_deref() != Some(tile_idxs.as_slice()) { return None; }
}
// For each selected tile, take the matching `ssz`-wide slice of the source.
let map_source = |src: &Arc<UOp>, src_idx: usize| -> Arc<UOp> {
let ssz = metadata.upcast_axes.source_size(src_idx);
let mut src_indices = Vec::with_capacity(tile_idxs.len() * ssz);
for &w in &tile_idxs {
let group = w / out_sz;
let start = group * ssz;
src_indices.extend(start..start + ssz);
}
src.gep(src_indices)
};
let scalar_base = metadata.dtype_out.base();
let result_dtype = DType::Scalar(scalar_base).vec(indices.len());
Some(UOp::new(
Op::Wmma {
a: map_source(a, 0),
b: map_source(b, 1),
c: map_source(c, 2),
metadata: metadata.clone(),
},
result_dtype,
))
},
}
}
/// Pattern set that recombines matching integer division and modulo terms.
///
/// All rules are instances of `x % n + (x / n) * n == x`, optionally scaled
/// by constants or carrying an extra addend `y`.
pub fn div_mod_recombine_dsl_patterns() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
// x % n + (x / n) * n -> x.
Add[Mod(x, n), Mul[Idiv(x, n), n]]
~> |x| Arc::clone(x),
// (x / a) % c + (x / b) * c -> x / a, when a * c == b.
Add[Mod(Idiv(x, a @const(a_val)), c @const(c_val)), Mul[Idiv(x, _b @const(b_val)), c]] => {
let a_int = a_val.try_int()?;
let c_int = c_val.try_int()?;
let b_int = b_val.try_int()?;
if a_int * c_int == b_int {
return x.try_div(a).ok();
}
None
},
// (x % c1) * c2 + (x / c1) * c3 -> x * c2, when c1 * c2 == c3.
Add[Mul[Mod(x, c1 @const(c1_val)), c2 @const(c2_val)], Mul[Idiv(x, c1), _c3 @const(c3_val)]] => {
let c1_int = c1_val.try_int()?;
let c2_int = c2_val.try_int()?;
let c3_int = c3_val.try_int()?;
if c1_int * c2_int == c3_int {
return Some(x.mul(c2));
}
None
},
// Same as the first rule with an extra addend y, in both association orders.
Add[Add[y, Mul[Idiv(x, n), n]], Mod(x, n)] ~> y.add(x),
Add[Add[y, Mod(x, n)], Mul[Idiv(x, n), n]] ~> y.add(x),
// Scaled rule with an extra addend y: -> y + x * c2, when c1 * c2 == c3.
Add[Add[y, Mul[Idiv(x, c1 @const(c1_val)), _c3 @const(c3_val)]], Mul[Mod(x, c1), c2 @const(c2_val)]] => {
let c1_int = c1_val.try_int()?;
let c2_int = c2_val.try_int()?;
let c3_int = c3_val.try_int()?;
(c1_int * c2_int == c3_int).then(|| y.add(&x.mul(c2)))
},
// Same, with the modulo term nested next to y instead.
Add[Add[y, Mul[Mod(x, c1 @const(c1_val)), c2 @const(c2_val)]], Mul[Idiv(x, c1), _c3 @const(c3_val)]] => {
let c1_int = c1_val.try_int()?;
let c2_int = c2_val.try_int()?;
let c3_int = c3_val.try_int()?;
(c1_int * c2_int == c3_int).then(|| y.add(&x.mul(c2)))
},
}
}
/// Pattern set that narrows 64-bit integer arithmetic to 32 bits when the
/// value bounds allow it, and pushes signed-integer casts through
/// `Index + const`.
pub fn long_to_int_narrowing_patterns() -> &'static TypedPatternMatcher {
use morok_ir::uop::properties::SoundVminVmaxProperty;
/// Returns `true` when `SoundVminVmaxProperty` yields `Int` bounds for
/// `uop` that both lie within the `i32` range.
fn fits_i32(uop: &Arc<UOp>) -> bool {
let Some((vmin, vmax)) = SoundVminVmaxProperty::get(uop) else { return false };
matches!(
(vmin, vmax),
(ConstValue::Int(min), ConstValue::Int(max))
if *min >= i32::MIN as i64 && *max <= i32::MAX as i64
)
}
crate::cached_patterns! {
// Int64 op whose operands and result all fit in i32:
// compute in Int32 and cast the result back to Int64.
for op in binary [Add, Mul, Sub, Mod, Max, Idiv, And, Or, Xor, Shl, Shr] {
result @ op(x, y)
if x.dtype() == DType::Scalar(ScalarDType::Int64)
&& fits_i32(x) && fits_i32(y) && fits_i32(result)
=> {
let i32_dt = DType::Scalar(ScalarDType::Int32);
let i64_dt = DType::Scalar(ScalarDType::Int64);
let x32 = x.cast(i32_dt.clone());
let y32 = y.cast(i32_dt.clone());
let r32 = UOp::new(Op::Binary(op, x32, y32), i32_dt);
Some(r32.cast(i64_dt))
},
},
// cast(x + c) -> cast(x) + cast(c) for Index-typed x and a signed integer target dtype.
Cast { src: Add(x, c @const(_cv)), dtype: cast_dt }
if x.dtype() == DType::Index && cast_dt.scalar().is_some_and(|s| s.is_signed() && s.is_int())
=> x.cast(cast_dt.clone()).try_add(&c.cast(cast_dt.clone())).ok(),
}
}