use std::sync::Arc;
use crate::rangeify::patterns::dead_axis_removal;
use crate::rewrite::graph_rewrite_bottom_up;
use morok_dtype::DType;
use morok_ir::{ConstValue, Op, UOp};
/// Reports whether `result` is `EXPAND(RESHAPE(BUFFERIZE))` or `RESHAPE(BUFFERIZE)`
/// where the innermost `BUFFERIZE` has an empty `ranges` list.
///
/// Any other root op, a missing `RESHAPE` layer, a non-`BUFFERIZE` leaf, or a
/// `BUFFERIZE` that still carries ranges yields `false`.
fn is_expand_reshape_bufferize(result: &Arc<UOp>) -> bool {
    // The outer EXPAND is optional: peel it off when present, otherwise the
    // root itself has to be the RESHAPE.
    let candidate = match result.op() {
        Op::Expand { src, .. } => src,
        Op::Reshape { .. } => result,
        _ => return false,
    };

    match candidate.op() {
        Op::Reshape { src: inner, .. } => {
            matches!(inner.op(), Op::Bufferize { ranges, .. } if ranges.is_empty())
        }
        _ => false,
    }
}
/// Returns the `BUFFERIZE` node found beneath `EXPAND(RESHAPE(..))` or `RESHAPE(..)`.
///
/// Yields `None` when the root is neither `EXPAND` nor `RESHAPE`, when the
/// (possibly unwrapped) node is not a `RESHAPE`, or when that `RESHAPE`'s source
/// is not a `BUFFERIZE`. Unlike `is_expand_reshape_bufferize`, the `ranges` of
/// the `BUFFERIZE` are not inspected.
fn get_inner_bufferize(result: &Arc<UOp>) -> Option<&Arc<UOp>> {
    // Skip over an optional outer EXPAND to reach the RESHAPE layer.
    let reshape_node = match result.op() {
        Op::Expand { src, .. } => src,
        Op::Reshape { .. } => result,
        _ => return None,
    };

    match reshape_node.op() {
        Op::Reshape { src: inner, .. } => {
            matches!(inner.op(), Op::Bufferize { .. }).then_some(inner)
        }
        _ => None,
    }
}
/// Rewrites a BUFFERIZE whose single range is `range_const(1, 0)` with
/// `dead_axis_removal` and checks two things:
///
/// 1. the result has the shape accepted by `is_expand_reshape_bufferize`, and
/// 2. the inner BUFFERIZE still references the very same `compute` node
///    (pointer equality with `x`).
#[test]
fn test_bufferize_with_size_1_range() {
    let x = UOp::param(1, 1, DType::Float32, None);
    let dead_range = UOp::range_const(1, 0);
    // `x` is cloned here because it is compared by pointer with the rewritten graph below.
    let bufferized = UOp::bufferize_global(x.clone(), vec![dead_range]);

    let matcher = dead_axis_removal();
    let result = graph_rewrite_bottom_up(&matcher, bufferized, &mut ());

    assert!(
        is_expand_reshape_bufferize(&result),
        "BUFFERIZE with dead axis should be restructured to EXPAND(RESHAPE(BUFFERIZE_no_ranges)), got: {}",
        result.tree()
    );

    // The shape assertion above implies an inner BUFFERIZE exists. Fail loudly rather
    // than silently skipping the compute check if the two helpers ever disagree.
    let inner_buf =
        get_inner_bufferize(&result).expect("shape check passed, so an inner BUFFERIZE must exist");
    let Op::Bufferize { compute, .. } = inner_buf.op() else {
        unreachable!("get_inner_bufferize only returns BUFFERIZE nodes");
    };
    assert!(Arc::ptr_eq(compute, &x), "Inner BUFFERIZE should have original compute");
}
/// Rewrites a BUFFERIZE whose three ranges are all built with `range_const(1, _)`
/// and checks that `dead_axis_removal` leaves a BUFFERIZE with no ranges under
/// the RESHAPE (optionally wrapped in EXPAND).
#[test]
fn test_bufferize_all_dead_axes() {
    let x = UOp::param(1, 1, DType::Float32, None);
    let dead_ranges = vec![UOp::range_const(1, 0), UOp::range_const(1, 1), UOp::range_const(1, 2)];
    // `x` is not used again, so it is moved instead of cloned.
    let bufferized = UOp::bufferize_global(x, dead_ranges);

    let matcher = dead_axis_removal();
    let result = graph_rewrite_bottom_up(&matcher, bufferized, &mut ());

    assert!(
        is_expand_reshape_bufferize(&result),
        "All dead axes should produce EXPAND(RESHAPE(BUFFERIZE_no_ranges)), got: {}",
        result.tree()
    );
}
/// Rewrites a BUFFERIZE over ranges built with `range_const(10, 0)`,
/// `range_const(1, 1)` and `range_const(20, 2)` whose compute is a plain
/// parameter, and checks that the result ends in a BUFFERIZE with no ranges.
///
/// Per the assertion message, every range is expected to be dropped because
/// the compute itself has no ranges.
#[test]
fn test_bufferize_mixed_live_dead_simple_compute() {
    let x = UOp::param(1, 1, DType::Float32, None);
    let range1 = UOp::range_const(10, 0);
    let dead_range = UOp::range_const(1, 1);
    let range2 = UOp::range_const(20, 2);
    // Neither `x` nor `bufferized` is needed afterwards, so both are moved rather than cloned.
    let bufferized = UOp::bufferize_global(x, vec![range1, dead_range, range2]);

    let matcher = dead_axis_removal();
    let result = graph_rewrite_bottom_up(&matcher, bufferized, &mut ());

    assert!(
        is_expand_reshape_bufferize(&result),
        "When compute has no ranges, result should be EXPAND(RESHAPE(BUFFERIZE_no_ranges)), got: {}",
        result.tree()
    );
}
/// Rewrites a BUFFERIZE over ranges built with `range_const(10, 0)` and
/// `range_const(20, 1)` whose compute is a plain parameter, and checks that the
/// result still ends in a BUFFERIZE with no ranges.
///
/// Per the assertion message, the ranges are expected to be removed because
/// the compute has no ranges of its own.
#[test]
fn test_bufferize_no_dead_axes_simple_compute() {
    let x = UOp::param(1, 1, DType::Float32, None);
    let ranges = vec![UOp::range_const(10, 0), UOp::range_const(20, 1)];
    // Neither `x` nor `bufferized` is needed afterwards, so both are moved rather than cloned.
    let bufferized = UOp::bufferize_global(x, ranges);

    let matcher = dead_axis_removal();
    let result = graph_rewrite_bottom_up(&matcher, bufferized, &mut ());

    assert!(
        is_expand_reshape_bufferize(&result),
        "When compute has no ranges, result should be EXPAND(RESHAPE(BUFFERIZE_no_ranges)), got: {}",
        result.tree()
    );
}
/// Builds `INDEX(BUFFERIZE(x, [range_const(10, 0), range_const(1, 1)]), [5, 0])`,
/// rewrites it with `buffer_folding + dead_axis_removal + movement_op_patterns`,
/// and checks the number of indices left on the resulting INDEX.
#[test]
fn test_index_after_dead_axis_removal() {
    use crate::rangeify::patterns::{buffer_folding, movement_op_patterns};

    let x = UOp::param(1, 1, DType::Float32, None);
    let live_range = UOp::range_const(10, 0);
    let dead_range = UOp::range_const(1, 1);
    // None of the operands are reused after construction, so they are moved rather than cloned.
    let bufferized = UOp::bufferize_global(x, vec![live_range, dead_range]);

    let idx1 = UOp::index_const(5);
    let idx2 = UOp::index_const(0);
    let indexed = UOp::index().buffer(bufferized).indices(vec![idx1, idx2]).call().unwrap();

    let buffer_simplify = buffer_folding() + dead_axis_removal() + movement_op_patterns();
    let result = graph_rewrite_bottom_up(&buffer_simplify, indexed, &mut ());

    // NOTE(review): the index-count assertion only runs when the rewritten root is still
    // an INDEX; any other root passes vacuously. Confirm whether that is intended.
    if let Op::Index { indices, .. } = result.op() {
        assert_eq!(indices.len(), 1, "Should have 1 index after dead axis removal + movement ops");
    }
}
/// Rewrites a BUFFERIZE whose single range is `range_const(1, 0)` and checks
/// that `dead_axis_removal` produces the shape accepted by
/// `is_expand_reshape_bufferize` (a BUFFERIZE with no ranges under a RESHAPE).
#[test]
fn test_bufferize_dead_axis_with_constants() {
    let x = UOp::param(1, 1, DType::Float32, None);
    let dead_range_const = UOp::range_const(1, 0);
    // `x` is not used again, so it is moved instead of cloned.
    let bufferized = UOp::bufferize_global(x, vec![dead_range_const]);

    let matcher = dead_axis_removal();
    let result = graph_rewrite_bottom_up(&matcher, bufferized, &mut ());

    assert!(
        is_expand_reshape_bufferize(&result),
        "Dead axis with constant 1 should produce EXPAND(RESHAPE(BUFFERIZE_no_ranges)), got: {}",
        result.tree()
    );
}
/// Applies `dead_axis_removal` twice in a row and checks that the second pass
/// changes nothing (the `tree()` renderings are equal), then checks that the
/// first result has the shape accepted by `is_expand_reshape_bufferize`.
#[test]
fn test_multiple_dead_axis_removal_passes() {
    let x = UOp::param(1, 1, DType::Float32, None);
    let live_range = UOp::range_const(10, 0);
    let dead_range1 = UOp::range_const(1, 1);
    let dead_range2 = UOp::range_const(1, 2);
    // The inputs are not reused after construction, so they are moved rather than cloned.
    let bufferized = UOp::bufferize_global(x, vec![live_range, dead_range1, dead_range2]);

    let matcher = dead_axis_removal();
    let result1 = graph_rewrite_bottom_up(&matcher, bufferized, &mut ());
    // `result1` is inspected again below, so the second pass gets its own handle.
    let result2 = graph_rewrite_bottom_up(&matcher, result1.clone(), &mut ());

    assert_eq!(result1.tree(), result2.tree(), "Dead axis removal should be idempotent");
    assert!(
        is_expand_reshape_bufferize(&result1),
        "Result should be EXPAND(RESHAPE(BUFFERIZE_no_ranges)), got: {}",
        result1.tree()
    );
}
/// Same scenario as the single-dead-axis tests, but the range end is built
/// explicitly as an `Index`-typed constant holding `ConstValue::UInt(1)` rather
/// than through `range_const`. Checks that the rewrite still yields a BUFFERIZE
/// with no ranges.
#[test]
fn test_dead_axis_uint_constant() {
    let x = UOp::param(1, 1, DType::Float32, None);
    let const_end = UOp::const_(DType::Index, ConstValue::UInt(1));
    let dead_range = UOp::range(const_end, 0);
    // `x` is not used again, so it is moved instead of cloned.
    let bufferized = UOp::bufferize_global(x, vec![dead_range]);

    let matcher = dead_axis_removal();
    let result = graph_rewrite_bottom_up(&matcher, bufferized, &mut ());

    assert!(
        is_expand_reshape_bufferize(&result),
        "Dead axis with UInt(1) should produce EXPAND(RESHAPE(BUFFERIZE_no_ranges)), got: {}",
        result.tree()
    );
}