use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use morok_dtype::DType;
use morok_ir::types::{BinaryOp, ConstValue};
use morok_ir::uop::cached_property::CachedProperty;
use morok_ir::uop::properties::VminVmaxProperty;
use morok_ir::{Op, UOp, UOpKey};
use crate::TypedPatternMatcher;
/// Parses one clause of a validity mask into an integer bound on an expression.
///
/// Returns `(X, is_upper, c)`, meaning `X <= c` when `is_upper` is true and `X >= c`
/// otherwise. Recognised clause shapes (with `X` of integer dtype):
///
/// - `(X < c) != true` and `!(X < c)`, i.e. `X >= c` → `(X, false, vmin(c))`
/// - `X < c`, i.e. `X <= c - 1`                       → `(X, true, vmax(c) - 1)`
///
/// The bound is read from whichever end of `c`'s value range keeps it sound when `c` is not
/// a literal. `X >= c` only guarantees `X >= vmin(c)`, and `X < c` only guarantees
/// `X <= vmax(c) - 1`. For a literal `c` the two ends are expected to coincide, so only
/// symbolic `c` is affected by this choice.
///
/// Returns `None` for any other clause shape, for non-integer `X`, or when the needed end of
/// `c`'s range is not a `ConstValue::Int`.
fn parse_valid(v: &Arc<UOp>) -> Option<(Arc<UOp>, bool, i64)> {
    // `(X < c) != true` and `!(X < c)` both negate `X < c`; peel the negation off first.
    let negated = match v.op() {
        Op::Binary(BinaryOp::Ne, lhs, rhs)
            if matches!(rhs.op(), Op::Const(cv) if cv.0 == ConstValue::Bool(true)) =>
        {
            Some(lhs)
        }
        Op::Unary(morok_ir::types::UnaryOp::Not, inner) => Some(inner),
        _ => None,
    };

    if let Some(cmp) = negated {
        // X >= c  ⇒  X >= vmin(c). Using vmax(c) here would claim a bound that does not hold
        // whenever c can be smaller than its maximum.
        if let Op::Binary(BinaryOp::Lt, x, c) = cmp.op()
            && x.dtype().is_int()
            && let (ConstValue::Int(lo), _) = VminVmaxProperty::get(c)
        {
            return Some((x.clone(), false, *lo));
        }
        return None;
    }

    // X < c  ⇒  X <= vmax(c) - 1. Saturate so `c == i64::MIN` cannot overflow. Such a clause
    // is unsatisfiable, so any bound derived from it is vacuously sound.
    if let Op::Binary(BinaryOp::Lt, x, c) = v.op()
        && x.dtype().is_int()
        && let (_, ConstValue::Int(hi)) = VminVmaxProperty::get(c)
    {
        return Some((x.clone(), true, hi.saturating_sub(1)));
    }
    None
}
/// Reports whether `op` is a leaf the simplex case split may replace with a fake variable:
/// a `Range`, `Special`, `DefineVar`, or `Const` node.
fn is_irreducible(op: &Op) -> bool {
    matches!(
        op,
        Op::Range { .. } | Op::Special { .. } | Op::DefineVar { .. } | Op::Const(..)
    )
}
/// Flattens a tree of binary `Add` nodes into its addends, in left-to-right order.
///
/// A node that is not an `Add` is returned as the single element of the result.
fn split_add(expr: &Arc<UOp>) -> Vec<Arc<UOp>> {
    let mut terms = Vec::new();
    let mut pending = vec![expr];
    while let Some(node) = pending.pop() {
        if let Op::Binary(BinaryOp::Add, left, right) = node.op() {
            // Push the right child first so the left subtree is expanded first, which keeps
            // the addends in source order.
            pending.push(right);
            pending.push(left);
        } else {
            terms.push(node.clone());
        }
    }
    terms
}
/// Flattens a tree of binary `And` nodes into its clauses, in left-to-right order.
///
/// A node that is not an `And` is returned as the single element of the result.
fn split_and(cond: &Arc<UOp>) -> Vec<Arc<UOp>> {
    let mut clauses = Vec::new();
    let mut stack = vec![cond];
    while let Some(current) = stack.pop() {
        match current.op() {
            Op::Binary(BinaryOp::And, lhs, rhs) => {
                // LIFO: push `rhs` before `lhs` so `lhs` is visited first.
                stack.push(rhs);
                stack.push(lhs);
            }
            _ => clauses.push(current.clone()),
        }
    }
    clauses
}
/// Folds `clauses` left-to-right into one `And` chain: `((c0 & c1) & c2) & ...`.
///
/// An empty slice yields the boolean constant `true`, the identity of `And`.
fn join_and(clauses: &[Arc<UOp>]) -> Arc<UOp> {
    match clauses.split_first() {
        None => UOp::const_(DType::Bool, ConstValue::Bool(true)),
        Some((head, tail)) => tail.iter().fold(head.clone(), |acc, clause| acc.and_(clause)),
    }
}
/// Simplifies an `And` chain of validity clauses, rewriting each clause under the ones
/// placed before it.
///
/// 1. The chain is split into clauses.
/// 2. The clauses are stably sorted. A clause whose bounded expression (from `parse_valid`)
///    appears in the backward slice of other clauses moves earlier, so its bound is
///    available when those clauses are simplified.
/// 3. Duplicate clauses (same node id) are dropped.
/// 4. Each remaining clause is rewritten by `uop_given_valid`, assuming every
///    previously emitted clause holds.
///
/// Returns `None` in any of these cases:
/// - `valid.has_index_in_sources()` is true;
/// - no clause is a bound that `parse_valid` understands;
/// - the result is clause-for-clause identical to the sorted input.
///
/// Otherwise returns the re-joined chain.
pub fn simplify_valid(valid: &Arc<UOp>) -> Option<Arc<UOp>> {
    if valid.has_index_in_sources() {
        return None;
    }
    let clauses = split_and(valid);
    if !clauses.iter().any(|c| parse_valid(c).is_some()) {
        return None;
    }

    // Priority of each clause: -1 for every *other* clause whose backward slice contains this
    // clause's bounded expression. Computed once per clause rather than inside the sort key,
    // where `sort_by_key` would rerun `parse_valid` and the O(n) scan on every comparison.
    let priorities: Vec<i32> = {
        let backward_slices: Vec<&HashSet<u64>> =
            clauses.iter().map(|c| c.backward_slice_ids()).collect();
        clauses
            .iter()
            .map(|v| {
                let Some((expr, _, _)) = parse_valid(v) else { return 0 };
                clauses
                    .iter()
                    .zip(&backward_slices)
                    .filter(|(other, slice)| other.id != v.id && slice.contains(&expr.id))
                    .fold(0i32, |priority, _| priority - 1)
            })
            .collect()
    };

    // `sort_by_key` is stable: clauses with equal priority keep their original order.
    let mut ranked: Vec<(i32, Arc<UOp>)> = priorities.into_iter().zip(clauses).collect();
    ranked.sort_by_key(|(priority, _)| *priority);
    let sorted_valids: Vec<Arc<UOp>> = ranked.into_iter().map(|(_, c)| c).collect();

    // Drop repeated clauses (first occurrence wins), then simplify each clause under the
    // conjunction of everything emitted so far.
    let mut seen = HashSet::new();
    let mut ret: Vec<Arc<UOp>> = Vec::with_capacity(sorted_valids.len());
    for stmt in sorted_valids.iter().filter(|c| seen.insert(c.id)) {
        let simplified = if ret.is_empty() {
            stmt.clone()
        } else {
            uop_given_valid(&join_and(&ret), stmt, true)
        };
        ret.push(simplified);
    }

    // Reordering alone is not reported as a change; deduplication or a rewritten clause is.
    let unchanged = ret.len() == sorted_valids.len()
        && ret.iter().zip(&sorted_valids).all(|(a, b)| a.id == b.id);
    if unchanged { None } else { Some(join_and(&ret)) }
}
fn uop_given_valid(valid: &Arc<UOp>, uop: &Arc<UOp>, try_simplex: bool) -> Arc<UOp> {
use morok_ir::rewrite::graph_rewrite;
type BoundsEntry = (Arc<UOp>, Option<i64>, Option<i64>);
let mut bounds: HashMap<u64, BoundsEntry> = HashMap::new();
for stmt in split_and(valid) {
if let Some((expr, is_upper, c)) = parse_valid(&stmt) {
let entry = bounds.entry(expr.id).or_insert_with(|| (expr.clone(), None, None));
if is_upper {
match entry.2 {
None => entry.2 = Some(c),
Some(existing) if c < existing => entry.2 = Some(c),
_ => {}
}
} else {
match entry.1 {
None => entry.1 = Some(c),
Some(existing) if c > existing => entry.1 = Some(c),
_ => {}
}
}
}
}
if bounds.is_empty() {
return uop.clone();
}
let mut all_candidates: Vec<(Arc<UOp>, Arc<UOp>)> = Vec::new();
let mut uop = uop.clone();
for (i, (_id, (expr, lower, upper))) in bounds.iter().enumerate() {
let (expr_vmin, expr_vmax) = VminVmaxProperty::get(expr);
let v0 = lower.unwrap_or_else(|| if let ConstValue::Int(v) = expr_vmin { *v } else { i64::MIN });
let v1 = upper.unwrap_or_else(|| if let ConstValue::Int(v) = expr_vmax { *v } else { i64::MAX });
let orig_min = if let ConstValue::Int(v) = expr_vmin { *v } else { i64::MIN };
let orig_max = if let ConstValue::Int(v) = expr_vmax { *v } else { i64::MAX };
if v0 == orig_min && v1 == orig_max {
continue;
}
let fake_var = UOp::define_var(format!("_valid_fake{i}"), v0, v1);
let fake_var = if expr.dtype() != fake_var.dtype() { fake_var.cast(expr.dtype()) } else { fake_var };
all_candidates.push((expr.clone(), fake_var));
if try_simplex {
let mut candidate_sets: Vec<Vec<(Arc<UOp>, Arc<UOp>)>> = vec![vec![all_candidates.last().unwrap().clone()]];
if let Op::Binary(BinaryOp::Add, ..) = expr.op()
&& v0 == 1
{
let addends = split_add(expr);
let all_irreducible_nonneg = addends.iter().all(|u| {
is_irreducible(u.op()) && {
let (vmin, _) = VminVmaxProperty::get(u);
matches!(vmin, ConstValue::Int(v) if *v >= 0)
}
});
if all_irreducible_nonneg {
let simplex_candidates: Vec<(Arc<UOp>, Arc<UOp>)> = addends
.iter()
.enumerate()
.map(|(j, xi)| {
let (_, xi_vmax) = VminVmaxProperty::get(xi);
let max_val = if let ConstValue::Int(v) = xi_vmax { *v } else { i64::MAX };
let fake = UOp::define_var(format!("_simplex_fake{j}"), 1, max_val);
let fake = if xi.dtype() != fake.dtype() { fake.cast(xi.dtype()) } else { fake };
(xi.clone(), fake)
})
.collect();
candidate_sets.push(simplex_candidates);
}
}
for candidates in &candidate_sets {
let new_uops: Vec<Arc<UOp>> = candidates
.iter()
.map(|(x, new_x)| {
#[allow(clippy::mutable_key_type)]
let map: HashMap<UOpKey, Arc<UOp>> = [(UOpKey(x.clone()), new_x.clone())].into();
uop.substitute(&map)
})
.collect();
if new_uops.iter().any(|u| Arc::ptr_eq(u, &uop)) {
continue;
}
let simplified: Vec<Arc<UOp>> = candidates
.iter()
.zip(new_uops.iter())
.map(|((x, new_x), u)| {
let s = graph_rewrite(crate::symbolic::patterns::symbolic(), u.clone(), &mut ());
#[allow(clippy::mutable_key_type)]
let rev: HashMap<UOpKey, Arc<UOp>> = [(UOpKey(new_x.clone()), x.clone())].into();
graph_rewrite(crate::symbolic::patterns::symbolic(), s.substitute(&rev), &mut ())
})
.collect();
if simplified.windows(2).all(|w| w[0].id == w[1].id) {
uop = simplified[0].clone();
}
}
}
}
if all_candidates.is_empty() {
return uop;
}
#[allow(clippy::mutable_key_type)]
let sub_map: HashMap<UOpKey, Arc<UOp>> =
all_candidates.iter().map(|(x, f)| (UOpKey(x.clone()), f.clone())).collect();
let substituted = uop.substitute(&sub_map);
if Arc::ptr_eq(&substituted, &uop) {
return uop;
}
let simplified = graph_rewrite(crate::symbolic::patterns::symbolic(), substituted, &mut ());
#[allow(clippy::mutable_key_type)]
let reverse_map: HashMap<UOpKey, Arc<UOp>> =
all_candidates.iter().map(|(x, f)| (UOpKey(f.clone()), x.clone())).collect();
let result = simplified.substitute(&reverse_map);
graph_rewrite(crate::symbolic::patterns::symbolic(), result, &mut ())
}
/// Simplifies the value of a gated expression `Where(cond, x, invalid)` using `cond`.
///
/// `x` is rewritten by `uop_given_valid`, assuming `cond` holds, with the simplex case split
/// disabled. Returns the rebuilt `Where(cond, x', invalid)`. Returns `None` if `x` did not
/// change (same node id) or if `UOp::try_where` rejects the operands.
pub fn gated_given_valid(cond: &Arc<UOp>, x: &Arc<UOp>, invalid: &Arc<UOp>) -> Option<Arc<UOp>> {
    let narrowed = uop_given_valid(cond, x, false);
    if narrowed.id == x.id {
        None
    } else {
        UOp::try_where(cond.clone(), narrowed, invalid.clone()).ok()
    }
}
/// Pattern matcher for validity-mask simplification.
///
/// - A `Bool`-typed `And` node is rewritten by `simplify_valid`: its clauses are simplified
///   under each other, and nothing changes when that returns `None`.
/// - `Where(cond, x, inv)` whose `inv` is `Op::Invalid` has `x` simplified under `cond` by
///   `gated_given_valid`.
///
/// Built through `crate::cached_patterns!`. Given the `&'static` return type, it is presumably
/// constructed once and cached — confirm against the macro.
pub fn pm_simplify_valid() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
// Boolean AND chains: simplify clauses given the clauses ordered before them.
valid @ And(_, _) if valid.dtype() == DType::Bool
=> |valid| simplify_valid(valid),
// `cond ? x : Invalid`: simplify `x` assuming `cond` holds.
Where(cond, x, inv) if matches!(inv.op(), Op::Invalid)
=> |cond, x, inv| gated_given_valid(cond, x, inv),
}
}
/// Drops the clauses of a gate that share no range with the gated value.
///
/// For `Where(cond, x, invalid)`, `cond` is split on `And`. A clause is kept only if at least
/// one of its ranges (compared by id) is also among `x`'s ranges. The kept clauses are
/// re-joined left-to-right with `try_and_op`, and the `Where` is rebuilt.
///
/// Returns `None` in any of these cases:
/// - `cond` has at most one clause;
/// - no clause would be dropped;
/// - every clause would be dropped;
/// - `try_and_op` or `UOp::try_where` fails.
fn drop_and_clauses(cond: &Arc<UOp>, x: &Arc<UOp>, invalid: &Arc<UOp>) -> Option<Arc<UOp>> {
    use morok_ir::types::BinaryOp;
    let clauses = cond.split_uop(BinaryOp::And);
    if clauses.len() <= 1 {
        return None;
    }

    let x_ranges: HashSet<u64> = x.ranges().iter().map(|r| r.id).collect();
    let (kept, dropped): (Vec<_>, Vec<_>) = clauses
        .iter()
        .partition(|clause| clause.ranges().iter().any(|r| x_ranges.contains(&r.id)));
    if dropped.is_empty() || kept.is_empty() {
        return None;
    }

    let mut rest = kept.into_iter();
    let first = rest.next()?.clone();
    let new_cond = rest.try_fold(first, |acc, clause| acc.try_and_op(clause).ok())?;
    UOp::try_where(new_cond, x.clone(), invalid.clone()).ok()
}
/// Pattern matcher that prunes gate clauses unrelated to the gated value.
///
/// For `Where(cond, x, inv)` whose `inv` is `Op::Invalid`, `drop_and_clauses` removes the
/// `And` clauses of `cond` that share no range with `x`. This changes where the value is
/// `Invalid`, so it is not a pure equivalence rewrite; presumably the dropped clauses are
/// enforced elsewhere — confirm with the pass that uses this matcher.
pub fn pm_drop_and_clauses() -> &'static TypedPatternMatcher {
crate::cached_patterns! {
// `cond ? x : Invalid`: keep only the clauses of `cond` whose ranges overlap `x`'s.
Where(cond, x, inv) if matches!(inv.op(), Op::Invalid)
=> |cond, x, inv| drop_and_clauses(cond, x, inv),
}
}