use std::sync::Arc;
use morok_dtype::{DType, ScalarDType};
use morok_ir::types::ConstValue;
use morok_ir::uop::cached_property::CachedProperty;
use morok_ir::uop::properties::VminVmaxProperty;
use morok_ir::{Op, UOp};
use crate::TypedPatternMatcher;
/// Chooses the concrete integer dtype used to materialise an `Index`-typed node.
///
/// Reads the node's `(vmin, vmax)` bounds from `VminVmaxProperty` and returns `Int32` when:
/// - both bounds are signed, with lower bound `>= i32::MIN` and upper bound `<= i32::MAX`;
/// - both bounds are unsigned and each is `<= i32::MAX`; or
/// - both bounds are booleans.
///
/// Every other combination falls back to `Int64`, including mixed-kind or non-integer bounds.
fn select_dtype(uop: &Arc<UOp>) -> DType {
    const I32_MIN: i64 = i32::MIN as i64;
    const I32_MAX: i64 = i32::MAX as i64;

    let (low, high) = VminVmaxProperty::get(uop);
    // Any bound that may escape the i32 range forces the 64-bit fallback.
    let needs_i64 = match (low, high) {
        (ConstValue::Int(lo), ConstValue::Int(hi)) => *lo < I32_MIN || *hi > I32_MAX,
        (ConstValue::UInt(lo), ConstValue::UInt(hi)) => {
            *lo > I32_MAX as u64 || *hi > I32_MAX as u64
        }
        (ConstValue::Bool(_), ConstValue::Bool(_)) => false,
        _ => true,
    };
    DType::Scalar(if needs_i64 { ScalarDType::Int64 } else { ScalarDType::Int32 })
}
/// Returns the common integer dtype for two operands of a lowered index computation.
///
/// The result is `Int64` if either `a` or `b` is the scalar `Int64`, and `Int32` otherwise.
///
/// NOTE(review): only `Int64` is treated as "wide". Any other dtype, such as an unsigned or
/// narrower integer, is assumed to fit in `Int32`. Confirm that callers only ever pass the
/// `Int32`/`Int64` dtypes produced by the lowering rules.
fn least_upper_dtype(a: &DType, b: &DType) -> DType {
    let is_wide = |d: &DType| matches!(d, DType::Scalar(ScalarDType::Int64));
    let scalar = if is_wide(a) || is_wide(b) { ScalarDType::Int64 } else { ScalarDType::Int32 };
    DType::Scalar(scalar)
}
/// Builds the rewrite rules that lower the abstract `DType::Index` to concrete integer dtypes.
///
/// Most rules match a node whose `Index`-typed inputs are `cast(x, Index)` wrappers. They
/// rebuild the node on a concrete integer dtype, wrap the result in a fresh `cast(_, Index)`,
/// and return it. The dtype is picked by `select_dtype` from the node's value range and/or
/// widened with `least_upper_dtype` across the unwrapped inputs. The wrapping cast means
/// consumers still see an `Index`-typed value.
///
/// The remaining rules do two things:
/// - they drop those wrappers where a concrete integer is accepted (`Index` indices, `Sink`
///   sources, `End` computations);
/// - they fold `where(cond, x, <invalid>)` index guards into the `Index` gate.
///
/// # Panics
///
/// Applying the gate-extraction rule panics with "ICE: AND gate merge" if AND-ing the
/// extracted guard conditions, with each other or with an existing gate, fails.
pub fn pm_lower_index_dtype() -> TypedPatternMatcher {
    crate::patterns! {
        // Scalar constant: materialise it in the dtype chosen from its value range.
        c @const(cv) if c.dtype() == DType::Index => |c, cv| {
            let dt = select_dtype(c);
            Some(UOp::const_(dt, cv).cast(DType::Index))
        },
        // Vector constant: element dtype chosen from `vc`'s value range, lane count preserved.
        vc @ VConst { values } if vc.dtype().base() == ScalarDType::Index => |vc, values| {
            let dt = select_dtype(vc);
            let vcount = vc.dtype().vcount();
            let vec_dt = dt.vec(vcount);
            let vec_index_dt = DType::Vector { scalar: ScalarDType::Index, count: vcount };
            let new_vc = UOp::new(Op::VConst { values: values.clone() }, vec_dt);
            Some(new_vc.cast(vec_index_dt))
        },
        // Symbolic variable: same bounds, concrete dtype chosen from its range.
        dv @ DefineVar { name, min_val, max_val } if dv.dtype() == DType::Index => |dv, name, min_val, max_val| {
            let dt = select_dtype(dv);
            let var = UOp::new(Op::DefineVar { name: name.clone(), min_val: *min_val, max_val: *max_val }, dt);
            Some(var.cast(DType::Index))
        },
        // Unary op on a `cast(src, Index)` operand: widen enough to cover both the node's own
        // range and `src`'s current dtype.
        node if matches!(node.op(), Op::Unary(_, _)) && node.dtype() == DType::Index => |node| {
            let Op::Unary(op, x) = node.op() else { return None };
            let Op::Cast { src, dtype } = x.op() else { return None };
            if *dtype != DType::Index {
                return None;
            }
            let dt = least_upper_dtype(&select_dtype(node), &src.dtype());
            let result = UOp::new(Op::Unary(*op, src.cast(dt.clone())), dt);
            Some(result.cast(DType::Index))
        },
        // Binary op whose operands are both `cast(_, Index)`. The node itself is not required
        // to be `Index`-typed, so the result is cast back to the node's own dtype.
        node if matches!(node.op(), Op::Binary(_, _, _)) => |node| {
            let Op::Binary(op, lhs, rhs) = node.op() else { return None };
            let (Op::Cast { src: x, dtype: lhs_dt }, Op::Cast { src: y, dtype: rhs_dt }) = (lhs.op(), rhs.op()) else {
                return None;
            };
            if *lhs_dt != DType::Index || *rhs_dt != DType::Index {
                return None;
            }
            let result_dt = select_dtype(node);
            let dt = least_upper_dtype(&result_dt, &least_upper_dtype(&x.dtype(), &y.dtype()));
            let result = UOp::alu(*op, x.cast(dt.clone()), y.cast(dt));
            Some(result.cast(node.dtype()))
        },
        // Select between two `cast(_, Index)` arms: both arms are unified to a common dtype.
        Where(cond, true_val, false_val)
            if true_val.dtype() == DType::Index && false_val.dtype() == DType::Index => |cond, true_val, false_val| {
            let (Op::Cast { src: x, dtype: t_dt }, Op::Cast { src: y, dtype: f_dt }) = (true_val.op(), false_val.op()) else {
                return None;
            };
            if *t_dt != DType::Index || *f_dt != DType::Index {
                return None;
            }
            let dt = least_upper_dtype(&x.dtype(), &y.dtype());
            let result = UOp::try_where(cond.clone(), x.cast(dt.clone()), y.cast(dt)).ok()?;
            Some(result.cast(DType::Index))
        },
        // Loop range: adopt the dtype of its unwrapped end bound.
        range @ Range { end, axis_id, axis_type } if range.dtype() == DType::Index => |range, end, axis_id, axis_type| {
            let Op::Cast { src: end_inner, dtype: end_dt } = end.op() else {
                return None;
            };
            if *end_dt != DType::Index {
                return None;
            }
            // Carry the original dependencies over; rebuilding with an empty list would
            // silently drop them from the graph.
            let Op::Range { deps, .. } = range.op() else { return None };
            let dt = end_inner.dtype();
            let result = UOp::new(
                Op::Range { end: end_inner.clone(), axis_id: *axis_id, axis_type: *axis_type, deps: deps.clone() },
                dt,
            );
            Some(result.cast(DType::Index))
        },
        // Special index: always emitted as `Int32`, regardless of the end bound's dtype.
        special @ Special { name, end } if special.dtype() == DType::Index => |name, end| {
            let Op::Cast { src: end_inner, dtype: end_dt } = end.op() else {
                return None;
            };
            if *end_dt != DType::Index {
                return None;
            }
            let i32_dt = DType::Scalar(ScalarDType::Int32);
            let result = UOp::new(Op::Special { end: end_inner.clone(), name: name.clone() }, i32_dt);
            Some(result.cast(DType::Index))
        },
        // Vectorize: only fires when every element is a `cast(_, Index)` wrapper.
        vec @ Vectorize { elements } if vec.dtype().base() == ScalarDType::Index => |vec, elements| {
            let inner: Option<Vec<_>> = elements
                .iter()
                .map(|e| match e.op() {
                    Op::Cast { src, dtype } if *dtype == DType::Index => Some(src.clone()),
                    _ => None,
                })
                .collect();
            let inner = inner?;
            let dt = select_dtype(vec);
            let casted: Vec<_> = inner.iter().map(|e| e.cast(dt.clone())).collect();
            let vec_index_dt = DType::Vector { scalar: ScalarDType::Index, count: elements.len() };
            Some(UOp::vectorize(casted.into()).cast(vec_index_dt))
        },
        // Bind: variable and bound value are unified to a common dtype before binding.
        Bind { var, value } if var.dtype() == DType::Index => |var, value| {
            let Op::Cast { src: var_inner, dtype: var_dt } = var.op() else { return None };
            let Op::Cast { src: val_inner, dtype: val_dt } = value.op() else { return None };
            if *var_dt != DType::Index || *val_dt != DType::Index {
                return None;
            }
            let dt = least_upper_dtype(&var_inner.dtype(), &val_inner.dtype());
            let bound = var_inner.cast(dt.clone()).bind(val_inner.cast(dt));
            Some(bound.cast(DType::Index))
        },
        // Index: drop `cast(_, Index)` wrappers around integer-typed indices. Returns `None`
        // when no index changed.
        node @ Index { buffer, indices, gate } => |node, buffer, indices, gate| {
            let mut changed = false;
            let new_indices: smallvec::SmallVec<[Arc<UOp>; 4]> = indices
                .iter()
                .map(|idx| {
                    if let Op::Cast { src, dtype } = idx.op()
                        && *dtype == DType::Index
                        && src.dtype().is_int()
                    {
                        changed = true;
                        return src.clone();
                    }
                    idx.clone()
                })
                .collect();
            if !changed {
                return None;
            }
            Some(UOp::new(Op::Index { buffer: buffer.clone(), indices: new_indices, gate: gate.clone() }, node.dtype()))
        },
        // Index: turn `where(cond, x, <invalid>)` indices into `x`, AND-ing every such `cond`
        // together with any existing gate.
        node @ Index { buffer, indices, gate } => |node, buffer, indices, gate| {
            let mut new_indices: smallvec::SmallVec<[Arc<UOp>; 4]> = smallvec::SmallVec::new();
            let mut conds: Vec<Arc<UOp>> = Vec::new();
            let mut changed = false;
            for idx in indices.iter() {
                if let Op::Ternary(morok_ir::TernaryOp::Where, cond, true_val, false_val) = idx.op()
                    && UOp::is_invalid_marker(false_val)
                {
                    new_indices.push(true_val.clone());
                    conds.push(cond.clone());
                    changed = true;
                    continue;
                }
                new_indices.push(idx.clone());
            }
            if !changed {
                return None;
            }
            let extracted = conds.into_iter().reduce(|a, b| a.try_and_op(&b).expect("ICE: AND gate merge"));
            let new_gate = match (gate, extracted) {
                (Some(existing), Some(ext)) => Some(existing.try_and_op(&ext).expect("ICE: AND gate merge")),
                (Some(existing), None) => Some(existing.clone()),
                (None, ext) => ext,
            };
            Some(UOp::new(Op::Index { buffer: buffer.clone(), indices: new_indices, gate: new_gate }, node.dtype()))
        },
        // Sink: unwrap any `cast(_, Index)` sources.
        Sink { sources } => |sources| {
            let mut changed = false;
            let new_sources: Vec<Arc<UOp>> = sources
                .iter()
                .map(|s| {
                    if let Op::Cast { src, dtype } = s.op() && *dtype == DType::Index {
                        changed = true;
                        src.clone()
                    } else {
                        s.clone()
                    }
                })
                .collect();
            if !changed {
                return None;
            }
            Some(UOp::sink(new_sources))
        },
        // End: unwrap a `cast(_, Index)` computation; the rebuilt End is typed `Void`.
        End { computation, ranges } => |computation, ranges| {
            let Op::Cast { src, dtype } = computation.op() else { return None };
            if *dtype != DType::Index {
                return None;
            }
            Some(UOp::new(Op::End { computation: src.clone(), ranges: ranges.clone() }, DType::Void))
        },
    }
}