use std::sync::Arc;
use morok_ir::{BinaryOp, ConstValue, DType, Op, UOp};
use smallvec::SmallVec;
use tracing::trace;
use crate::TypedPatternMatcher;
/// Counts the nodes reachable from `uop` (as listed by `toposort()`) whose
/// operation is a binary integer division (`Idiv`) or modulo (`Mod`).
pub fn count_divmod(uop: &Arc<UOp>) -> usize {
    let mut total = 0usize;
    for node in uop.toposort().iter() {
        if let Op::Binary(BinaryOp::Idiv | BinaryOp::Mod, _, _) = node.op() {
            total += 1;
        }
    }
    total
}
/// Tries to determine the extent (number of values) of a single index expression.
///
/// The following cases are tried in order:
///
/// 1. `Where(cond, value, Invalid)`: the answer comes from
///    `extract_dim_from_validity(cond, value)`.
/// 2. A bare `Range`: its `end` when that is an integer constant, otherwise `None`.
/// 3. A `DefineVar`: `max_val + 1`.
/// 4. Anything else: the product of the integer-constant `end`s of every `Range`
///    found in the expression's `toposort()`. A single range with a
///    non-constant `end` makes the whole result `None`.
/// 5. When step 4 found no range (or the product is not positive):
///    `vmax - vmin + 1`, provided both bounds are integers and `vmax >= vmin`.
///
/// Returns `None` when no case produces a size.
pub fn extract_index_dimension(idx_uop: &Arc<UOp>) -> Option<i64> {
    // Integer value of a range's `end` operand, if it is an integer constant.
    fn constant_end(end: &Arc<UOp>) -> Option<i64> {
        match end.op() {
            Op::Const(cv) => match cv.0 {
                ConstValue::Int(size) => Some(size),
                _ => None,
            },
            _ => None,
        }
    }

    // Shapes that can be answered by looking at the root node alone.
    match idx_uop.op() {
        Op::Ternary(morok_ir::TernaryOp::Where, cond, true_val, false_val)
            if matches!(false_val.op(), Op::Invalid) =>
        {
            return extract_dim_from_validity(cond, true_val);
        }
        Op::Range { end, .. } => return constant_end(end),
        Op::DefineVar { max_val, .. } => return Some(*max_val + 1),
        _ => {}
    }

    // General expression: multiply the extents of all ranges it is built from.
    let mut product = 1i64;
    let mut found_range = false;
    for node in idx_uop.toposort() {
        let Op::Range { end, .. } = node.op() else {
            continue;
        };
        // A range without a constant end makes the extent unknowable here.
        product *= constant_end(end)?;
        found_range = true;
    }
    if found_range && product > 0 {
        return Some(product);
    }

    // Last resort: size of the value interval reported for the expression.
    match (idx_uop.vmin(), idx_uop.vmax()) {
        (ConstValue::Int(lo), ConstValue::Int(hi)) if hi >= lo => Some(hi - lo + 1),
        _ => None,
    }
}
/// Derives an index extent from a validity condition `cond` guarding `true_val`.
///
/// Recognised condition shapes:
///
/// * `lhs < C` with integer constant `C` → `C` (`lhs` is not inspected).
/// * `(x >= B) & (x < C)`, operands in either order, where both comparisons
///   refer to the same node and `B`, `C` are integer constants → `C - B`.
/// * `r >= B` where `r` is a `Range` with integer-constant end `E` → `E - B`.
///
/// When the condition matches none of these, the result is `vmax - vmin + 1`
/// of `true_val` if both bounds are integers with `vmax >= vmin`, else `None`.
fn extract_dim_from_validity(cond: &Arc<UOp>, true_val: &Arc<UOp>) -> Option<i64> {
    let from_condition = match cond.op() {
        Op::Binary(BinaryOp::Lt, _, upper) => const_int(upper),
        Op::Binary(BinaryOp::And, lhs, rhs) => extract_ge_lt_bounds(lhs, rhs)
            .or_else(|| extract_ge_lt_bounds(rhs, lhs))
            .map(|(begin, upper)| upper - begin),
        Op::Binary(BinaryOp::Ge, rng, begin_uop) => match (const_int(begin_uop), rng.op()) {
            (Some(begin), Op::Range { end, .. }) => {
                const_int(end).map(|rng_end| rng_end - begin)
            }
            _ => None,
        },
        _ => None,
    };
    if let Some(dim) = from_condition {
        return Some(dim);
    }

    // The condition was not informative; fall back to the value interval.
    match (true_val.vmin(), true_val.vmax()) {
        (ConstValue::Int(lo), ConstValue::Int(hi)) if hi >= lo => Some(hi - lo + 1),
        _ => None,
    }
}
/// Returns the integer payload of `uop` when it is a `Const` node holding a
/// `ConstValue::Int`; `None` for every other node or constant kind.
fn const_int(uop: &Arc<UOp>) -> Option<i64> {
    match uop.op() {
        Op::Const(cv) => match cv.0 {
            ConstValue::Int(value) => Some(value),
            _ => None,
        },
        _ => None,
    }
}
/// Extracts `(begin, upper)` from a pair of comparisons `x >= begin` and
/// `x < upper`.
///
/// Succeeds only when `maybe_ge` is a `Ge`, `maybe_lt` is an `Lt`, both compare
/// the very same node (pointer identity via `Arc::ptr_eq`), and both
/// right-hand sides are integer constants. Returns `None` otherwise.
fn extract_ge_lt_bounds(maybe_ge: &Arc<UOp>, maybe_lt: &Arc<UOp>) -> Option<(i64, i64)> {
    match (maybe_ge.op(), maybe_lt.op()) {
        (
            Op::Binary(BinaryOp::Ge, ge_subject, begin_uop),
            Op::Binary(BinaryOp::Lt, lt_subject, upper_uop),
        ) if Arc::ptr_eq(ge_subject, lt_subject) => {
            const_int(begin_uop).zip(const_int(upper_uop))
        }
        _ => None,
    }
}
/// Computes row-major strides for the given dimension sizes.
///
/// The last dimension gets stride 1 and every earlier dimension gets the
/// product of all dimension sizes after it, e.g. `[2, 3, 4]` → `[12, 4, 1]`.
/// The first dimension's size is never read. An empty input yields an empty
/// vector.
pub fn compute_row_major_strides(dims: &[i64]) -> Vec<i64> {
    let rank = dims.len();
    let mut strides = vec![1i64; rank];
    // Walk from the innermost axis outwards, carrying the product of the
    // sizes seen so far; that product is the stride of the next-outer axis.
    let mut running = 1i64;
    for axis in (1..rank).rev() {
        running *= dims[axis];
        strides[axis - 1] = running;
    }
    strides
}
/// Reports whether at least one index has a dtype whose `vcount()` exceeds 1.
fn any_index_vectorized(indices: &[Arc<UOp>]) -> bool {
    for index in indices {
        if index.dtype().vcount() > 1 {
            return true;
        }
    }
    false
}
/// Returns the `vcount()` of the first index whose dtype has more than one
/// lane, or 1 when no index is vectorized.
fn get_vector_count(indices: &[Arc<UOp>]) -> usize {
    indices
        .iter()
        .map(|index| index.dtype().vcount())
        .find(|&lanes| lanes > 1)
        .unwrap_or(1)
}
/// Builds a single `DType::Index` expression equal to the sum of
/// `indices[i] * strides[i]`.
///
/// * `indices` and `strides` are paired positionally; surplus elements of the
///   longer slice are ignored.
/// * Pairs with stride 0 contribute nothing; with stride 1 the index is used
///   as-is, without a multiplication node.
/// * The accumulator starts as the index constant 0 and is *replaced* by the
///   next term for as long as it is still an integer constant 0, so no
///   `0 + term` node is created.
///
/// Returns the index constant 0 when no pair contributes a term.
pub fn build_linear_index(indices: &[Arc<UOp>], strides: &[i64]) -> Arc<UOp> {
    indices
        .iter()
        .zip(strides)
        .filter(|&(_, &stride)| stride != 0)
        .fold(UOp::index_const(0), |acc, (index, &stride)| {
            let term = match stride {
                1 => Arc::clone(index),
                _ => {
                    let scale = UOp::index_const(stride);
                    UOp::new(Op::Binary(BinaryOp::Mul, Arc::clone(index), scale), DType::Index)
                }
            };
            // While nothing but a literal zero has been accumulated, adding
            // would be pointless: take the term itself instead.
            let acc_is_zero =
                matches!(acc.op(), Op::Const(cv) if matches!(cv.0, ConstValue::Int(0)));
            if acc_is_zero {
                term
            } else {
                UOp::new(Op::Binary(BinaryOp::Add, acc, term), DType::Index)
            }
        })
}
/// Builds a vectorized linear index with `vcount` lanes.
///
/// For each lane, every index whose dtype has more than one lane is narrowed
/// to that lane with `gep`, while all other indices are reused unchanged. The
/// per-lane indices are combined by `build_linear_index`, and the lane
/// results are packed with `UOp::vectorize`.
fn build_vectorized_linear_index(indices: &[Arc<UOp>], strides: &[i64], vcount: usize) -> Arc<UOp> {
    let linear_for_lane = |lane: usize| -> Arc<UOp> {
        let mut per_lane: Vec<Arc<UOp>> = Vec::with_capacity(indices.len());
        for index in indices {
            let scalar = match index.dtype().vcount() {
                lanes if lanes > 1 => index.gep(vec![lane]),
                _ => Arc::clone(index),
            };
            per_lane.push(scalar);
        }
        build_linear_index(&per_lane, strides)
    };

    let lanes: SmallVec<[Arc<UOp>; 4]> = (0..vcount).map(linear_for_lane).collect();
    UOp::vectorize(lanes)
}
/// Returns the pattern matcher that collapses an `Index` with more than one
/// index operand into an `Index` with a single, linearized operand.
///
/// For a matching node the rewrite:
/// 1. extracts an extent for every index via `extract_index_dimension` and
///    gives up (returns `None`) if any extent is unknown;
/// 2. derives row-major strides from those extents;
/// 3. builds the linear index, per lane when any index is vectorized;
/// 4. rejects the rewrite when the linear index contains more `Idiv`/`Mod`
///    nodes than the original indices summed together;
/// 5. otherwise emits a new `Index` with the same buffer, gate and dtype.
pub fn pm_linearize_multi_index() -> &'static TypedPatternMatcher<()> {
    crate::cached_patterns! {
        idx @ Index { buffer, indices, gate } if indices.len() > 1 => |idx, buffer, indices, gate| {
            // Collecting into `Option<Vec<_>>` yields `None` as soon as one
            // dimension cannot be determined.
            let dims = match indices.iter().map(extract_index_dimension).collect::<Option<Vec<i64>>>() {
                Some(known) => known,
                None => {
                    trace!(
                        uop_id = idx.id,
                        buffer_id = buffer.id,
                        "linearize_multi_index: couldn't extract all dimensions, skipping"
                    );
                    return None;
                }
            };
            let strides = compute_row_major_strides(&dims);

            let linear_index = if any_index_vectorized(indices) {
                build_vectorized_linear_index(indices, &strides, get_vector_count(indices))
            } else {
                build_linear_index(indices, &strides)
            };

            // Only accept the rewrite if it does not add div/mod operations.
            let original_divmod: usize = indices.iter().map(count_divmod).sum();
            let linearized_divmod = count_divmod(&linear_index);
            if linearized_divmod > original_divmod {
                trace!(
                    uop_id = idx.id,
                    original_divmod,
                    linearized_divmod,
                    "linearize_multi_index: rejected (would increase divmod), keeping multi-index"
                );
                return None;
            }

            trace!(
                uop_id = idx.id,
                index_dims = ?dims,
                original_divmod,
                linearized_divmod,
                "linearize_multi_index: linearizing {}-dimensional index",
                indices.len()
            );

            Some(UOp::new(
                Op::Index {
                    buffer: buffer.clone(),
                    indices: smallvec::smallvec![linear_index],
                    gate: gate.clone(),
                },
                idx.dtype().clone(),
            ))
        },
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use morok_ir::{AxisId, AxisType};

    // Builds a `Loop`-typed range whose end is the index constant `extent`.
    fn loop_range(extent: i64, axis: AxisId) -> Arc<UOp> {
        UOp::range_axis(UOp::index_const(extent), axis, AxisType::Loop)
    }

    // Strides are the products of all trailing dimension sizes.
    #[test]
    fn test_compute_row_major_strides() {
        let cases = [
            (vec![2, 3, 4], vec![12, 4, 1]),
            (vec![5, 10], vec![10, 1]),
            (vec![100], vec![1]),
        ];
        for (dims, expected) in cases {
            assert_eq!(compute_row_major_strides(&dims), expected);
        }
    }

    // Two contributing terms must be joined by an `Add` at the root.
    #[test]
    fn test_build_linear_index() {
        let row = UOp::index_const(2);
        let col = UOp::index_const(3);
        let combined = build_linear_index(&[row, col], &[10, 1]);
        assert!(matches!(combined.op(), Op::Binary(BinaryOp::Add, _, _)));
    }

    // A bare range reports its constant end as the dimension.
    #[test]
    fn test_extract_index_dimension_range() {
        let range = loop_range(10, AxisId::Renumbered(0));
        assert_eq!(extract_index_dimension(&range), Some(10));
    }

    // `r1 * 8 + r2` over ranges of 4 and 8 yields the product of both extents.
    #[test]
    fn test_extract_index_dimension_complex_expression() {
        let outer = loop_range(4, AxisId::Renumbered(0));
        let inner = loop_range(8, AxisId::Renumbered(1));
        let scale = UOp::index_const(8);
        let scaled = UOp::new(Op::Binary(BinaryOp::Mul, outer, scale), DType::Index);
        let flat = UOp::new(Op::Binary(BinaryOp::Add, scaled, inner), DType::Index);
        assert_eq!(extract_index_dimension(&flat), Some(32));
    }
}