use std::f32::consts::PI;
use std::sync::Arc;
use morok_dtype::DType;
use morok_ir::{AxisId, AxisType, BufferizeOpts, ConstValue, Op, UOp};
use crate::pattern::RewriteResult;
use crate::rangeify::IndexingContext;
use crate::rangeify::patterns;
/// `early_rewrites` must strip a DETACH wrapper and return the wrapped node itself
/// (same `Arc`, not a copy).
#[test]
fn test_early_rewrites_detach_removal() {
    let matcher = patterns::early_rewrites();
    let source = UOp::native_const(42.0f32);
    let wrapped = source.detach();

    match matcher.rewrite(&wrapped, &mut ()) {
        RewriteResult::Rewritten(replacement) => {
            assert!(Arc::ptr_eq(&replacement, &source), "Should return the source");
        }
        _ => panic!("Should rewrite DETACH"),
    }
}
/// `early_rewrites` must strip a CONTIGUOUS_BACKWARD wrapper and return the wrapped
/// node itself (pointer-identical).
#[test]
fn test_early_rewrites_contiguous_backward_removal() {
    let matcher = patterns::early_rewrites();
    let source = UOp::native_const(PI);
    let wrapped = source.contiguous_backward();

    match matcher.rewrite(&wrapped, &mut ()) {
        RewriteResult::Rewritten(replacement) => {
            assert!(Arc::ptr_eq(&replacement, &source), "Should return the source");
        }
        _ => panic!("Should rewrite CONTIGUOUS_BACKWARD"),
    }
}
/// `early_rewrites` must leave unrelated nodes alone: a bare constant and a binary
/// add both have to come back as `NoMatch`.
#[test]
fn test_early_rewrites_no_match_for_other_ops() {
    let matcher = patterns::early_rewrites();

    // Case 1: a lone constant.
    let lone_const = UOp::native_const(1.0f32);
    if !matches!(matcher.rewrite(&lone_const, &mut ()), RewriteResult::NoMatch) {
        panic!("Should not match CONST");
    }

    // Case 2: a binary add of two constants.
    let lhs = UOp::native_const(1.0f32);
    let rhs = UOp::native_const(2.0f32);
    let sum = lhs.try_add(&rhs).unwrap();
    if !matches!(matcher.rewrite(&sum, &mut ()), RewriteResult::NoMatch) {
        panic!("Should not match Binary ops");
    }
}
/// A single rewrite step on DETACH(DETACH(x)) peels exactly one layer: the result
/// must be the inner DETACH node, not `x`.
#[test]
fn test_early_rewrites_nested_detach() {
    let matcher = patterns::early_rewrites();
    let base = UOp::native_const(1.0f32);
    let inner = base.detach();
    let outer = inner.detach();

    let result = matcher.rewrite(&outer, &mut ());
    assert!(matches!(result, RewriteResult::Rewritten(_)));

    // The assertion above guarantees the `Rewritten` arm; anything else is ignored.
    match result {
        RewriteResult::Rewritten(peeled) => {
            assert!(Arc::ptr_eq(&peeled, &inner), "Should unwrap outer DETACH to inner DETACH");
        }
        _ => {}
    }
}
/// INDEX(BUFFERIZE(x, [r]), [r]) with the very same range is a no-op round trip;
/// `buffer_folding` must collapse it to `x` itself.
#[test]
fn test_buffer_folding_noop_bufferize() {
    let matcher = patterns::buffer_folding();
    let compute = UOp::native_const(1.0f32);
    let axis = UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(0), AxisType::Loop);
    let buffered = UOp::bufferize(compute.clone(), vec![axis.clone()], BufferizeOpts::local());
    let indexed = UOp::index().buffer(buffered).indices(vec![axis]).call().unwrap();

    match matcher.rewrite(&indexed, &mut ()) {
        RewriteResult::Rewritten(replacement) => {
            assert!(Arc::ptr_eq(&replacement, &compute), "Should return the compute directly");
        }
        _ => panic!("Should remove noop BUFFERIZE"),
    }
}
/// BUFFERIZE wrapped around a constant is pointless; `buffer_folding` must return
/// the constant node itself.
#[test]
fn test_buffer_folding_bufferize_const() {
    let matcher = patterns::buffer_folding();
    let constant = UOp::native_const(42.0f32);
    let axis = UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(0), AxisType::Loop);
    let buffered = UOp::bufferize(constant.clone(), vec![axis], BufferizeOpts::local());

    match matcher.rewrite(&buffered, &mut ()) {
        RewriteResult::Rewritten(replacement) => {
            assert!(Arc::ptr_eq(&replacement, &constant), "Should return the constant directly");
        }
        _ => panic!("Should remove BUFFERIZE from CONST"),
    }
}
/// INDEX applied directly to a constant must fold away, yielding the constant node
/// itself.
#[test]
fn test_buffer_folding_index_const() {
    let matcher = patterns::buffer_folding();
    let constant = UOp::native_const(PI);
    let axis = UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(0), AxisType::Loop);
    let indexed = UOp::index().buffer(constant.clone()).indices(vec![axis]).call().unwrap();

    match matcher.rewrite(&indexed, &mut ()) {
        RewriteResult::Rewritten(replacement) => {
            assert!(Arc::ptr_eq(&replacement, &constant), "Should return the constant directly");
        }
        _ => panic!("Should remove INDEX from CONST"),
    }
}
/// COPY of a constant to a device must fold away: `buffer_folding` is expected to
/// return the constant node itself (pointer-identical).
#[test]
fn test_buffer_folding_copy_const() {
    let matcher = patterns::buffer_folding();
    let const_val = UOp::native_const(1.0f32);
    let device = UOp::device(morok_ir::DeviceSpec::Cpu);
    let copy = const_val.copy(device);

    // The argument is a borrow of the `copy` node built above. The text had been
    // mangled into a copyright sign by an HTML-entity decode of `&copy`.
    let result = matcher.rewrite(&copy, &mut ());

    assert!(matches!(result, RewriteResult::Rewritten(_)), "Should remove COPY from CONST");
    if let RewriteResult::Rewritten(rewritten) = result {
        assert!(Arc::ptr_eq(&rewritten, &const_val), "Should return the constant directly");
    }
}
/// Indexing a BUFFERIZE with a range other than the one it was built with is not a
/// no-op. Any outcome is accepted except a rewrite that produces a CONST.
#[test]
fn test_buffer_folding_no_match_different_ranges() {
    let matcher = patterns::buffer_folding();
    let compute = UOp::native_const(1.0f32);
    let build_axis = UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(0), AxisType::Loop);
    let lookup_axis = UOp::range_axis(UOp::index_const(20), AxisId::Renumbered(1), AxisType::Loop);
    let buffered = UOp::bufferize(compute, vec![build_axis], BufferizeOpts::local());
    let indexed = UOp::index().buffer(buffered).indices(vec![lookup_axis]).call().unwrap();

    // `NoMatch` and `Gate` are both fine; only a `Rewritten` result is inspected.
    if let RewriteResult::Rewritten(rewritten) = matcher.rewrite(&indexed, &mut ()) {
        assert!(!matches!(rewritten.op(), Op::Const(_)));
    }
}
/// BUFFERIZE over a constant with a single extent-1 range. If the pattern rewrites,
/// the result must be RESHAPE(BUFFERIZE with no ranges), optionally wrapped in EXPAND.
///
/// NOTE(review): a non-rewrite outcome is silently accepted here, unlike the sibling
/// dead-axis tests, which panic.
#[test]
fn test_dead_axis_removal_single_dead_axis() {
    let matcher = patterns::dead_axis_removal();
    let x = UOp::native_const(1.0f32);
    let unit_axis = UOp::range_axis(UOp::index_const(1), AxisId::Renumbered(0), AxisType::Loop);
    let buffered = UOp::bufferize(x.clone(), vec![unit_axis], BufferizeOpts::local());

    if let RewriteResult::Rewritten(rewritten) = matcher.rewrite(&buffered, &mut ()) {
        // Peel an optional outer EXPAND to reach the RESHAPE layer.
        let reshape_node = match rewritten.op() {
            Op::Expand { src, .. } => src,
            Op::Reshape { .. } => &rewritten,
            _ => panic!("Expected EXPAND or RESHAPE, got: {}", rewritten.tree()),
        };

        match reshape_node.op() {
            Op::Reshape { src: inner, .. } => assert!(
                matches!(inner.op(), Op::Bufferize { ranges, .. } if ranges.is_empty()),
                "Inner should be BUFFERIZE with no ranges, got: {}",
                rewritten.tree()
            ),
            _ => panic!("Expected RESHAPE inside result, got: {}", rewritten.tree()),
        }
    }
}
/// BUFFERIZE over a constant with one extent-10 range and one extent-1 range.
/// The pattern must rewrite to EXPAND(RESHAPE(BUFFERIZE with no ranges)).
#[test]
fn test_dead_axis_removal_mixed_axes() {
    let matcher = patterns::dead_axis_removal();
    let x = UOp::native_const(1.0f32);
    let live_axis = UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(0), AxisType::Loop);
    let unit_axis = UOp::range_axis(UOp::index_const(1), AxisId::Renumbered(1), AxisType::Loop);
    let buffered = UOp::bufferize(x.clone(), vec![live_axis.clone(), unit_axis], BufferizeOpts::local());

    // Walk the expected shape layer by layer, failing at the first mismatch.
    let rewritten = match matcher.rewrite(&buffered, &mut ()) {
        RewriteResult::Rewritten(node) => node,
        _ => panic!("Expected pattern to match and rewrite"),
    };
    let reshape_node = match rewritten.op() {
        Op::Expand { src, .. } => src,
        _ => panic!("Expected EXPAND(RESHAPE(BUFFERIZE_no_ranges)), got: {}", rewritten.tree()),
    };
    let inner = match reshape_node.op() {
        Op::Reshape { src, .. } => src,
        _ => panic!("Expected RESHAPE inside EXPAND, got: {}", rewritten.tree()),
    };
    assert!(
        matches!(inner.op(), Op::Bufferize { ranges, .. } if ranges.is_empty()),
        "Inner should be BUFFERIZE with no ranges, got: {}",
        rewritten.tree()
    );
}
/// BUFFERIZE over a constant with two ranges of extent 10 and 20. The test expects
/// every range to be treated as dead, so the rewrite must produce
/// EXPAND(RESHAPE(BUFFERIZE with no ranges)).
#[test]
fn test_dead_axis_removal_no_dead_axes_simple_compute() {
    let matcher = patterns::dead_axis_removal();
    let x = UOp::native_const(1.0f32);
    let axis_a = UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(0), AxisType::Loop);
    let axis_b = UOp::range_axis(UOp::index_const(20), AxisId::Renumbered(1), AxisType::Loop);
    let buffered = UOp::bufferize(x.clone(), vec![axis_a, axis_b], BufferizeOpts::local());

    // Walk the expected shape layer by layer, failing at the first mismatch.
    let rewritten = match matcher.rewrite(&buffered, &mut ()) {
        RewriteResult::Rewritten(node) => node,
        _ => panic!("Expected pattern to match and rewrite when all ranges are dead"),
    };
    let reshape_node = match rewritten.op() {
        Op::Expand { src, .. } => src,
        _ => panic!("Expected EXPAND(RESHAPE(BUFFERIZE_no_ranges)), got: {}", rewritten.tree()),
    };
    let inner = match reshape_node.op() {
        Op::Reshape { src, .. } => src,
        _ => panic!("Expected RESHAPE inside EXPAND, got: {}", rewritten.tree()),
    };
    assert!(
        matches!(inner.op(), Op::Bufferize { ranges, .. } if ranges.is_empty()),
        "Inner should be BUFFERIZE with no ranges, got: {}",
        rewritten.tree()
    );
}
/// BUFFERIZE around a cheap add of two constants. If `buffer_removal` rewrites, the
/// result must be the add node itself; a non-rewrite outcome is tolerated.
#[test]
fn test_buffer_removal_cheap_compute() {
    let matcher = patterns::buffer_removal();
    let lhs = UOp::native_const(1.0f32);
    let rhs = UOp::native_const(2.0f32);
    let sum = lhs.try_add(&rhs).unwrap();
    let axis = UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(0), AxisType::Loop);
    let buffered = UOp::bufferize(sum.clone(), vec![axis], BufferizeOpts::local());

    if let RewriteResult::Rewritten(replacement) = matcher.rewrite(&buffered, &mut ()) {
        assert!(Arc::ptr_eq(&replacement, &sum), "Should remove BUFFERIZE from cheap compute");
    }
}
/// BUFFERIZE(CONTIGUOUS(const)) must survive `buffer_removal`: the matcher has to
/// report `NoMatch`.
#[test]
fn test_buffer_removal_always_run_ops_kept() {
    let matcher = patterns::buffer_removal();
    let forced = UOp::const_(DType::Float32, ConstValue::Float(1.0)).contiguous();
    let axis = UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(0), AxisType::Loop);
    let buffered = UOp::bufferize(forced, vec![axis], BufferizeOpts::local());

    let outcome = matcher.rewrite(&buffered, &mut ());
    if !matches!(outcome, RewriteResult::NoMatch) {
        panic!("BUFFERIZE(CONTIGUOUS) must be kept - always-run ops need their buffers");
    }
}
/// BUFFERIZE(BUFFERIZE(x)). If `buffer_removal` rewrites, the result must still be
/// a BUFFERIZE whose compute is `x` directly (inner layer flattened away).
/// A non-rewrite outcome is tolerated.
#[test]
fn test_buffer_removal_nested_bufferize() {
    let matcher = patterns::buffer_removal();
    let x = UOp::const_(DType::Float32, ConstValue::Float(1.0));
    let inner_axis = UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(0), AxisType::Loop);
    let inner = UOp::bufferize(x.clone(), vec![inner_axis], BufferizeOpts::local());
    let outer_axis = UOp::range_axis(UOp::index_const(20), AxisId::Renumbered(1), AxisType::Loop);
    let outer = UOp::bufferize(inner, vec![outer_axis.clone()], BufferizeOpts::local());

    if let RewriteResult::Rewritten(rewritten) = matcher.rewrite(&outer, &mut ()) {
        match rewritten.op() {
            Op::Bufferize { compute, .. } => {
                assert!(Arc::ptr_eq(compute, &x), "Should have compute pointing to x, not inner BUFFERIZE");
            }
            _ => panic!("Expected BUFFERIZE operation"),
        }
    }
}
/// BUFFERIZE around a LOAD from a buffer must be left in place: `buffer_removal`
/// has to report `NoMatch`.
#[test]
fn test_buffer_removal_no_match_expensive_compute() {
    let matcher = patterns::buffer_removal();
    let source_buffer = UOp::buffer_id(Some(0));
    let offset = UOp::index_const(0);
    let loaded = UOp::load().buffer(source_buffer).index(offset).call();
    let axis = UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(0), AxisType::Loop);
    let buffered = UOp::bufferize(loaded, vec![axis], BufferizeOpts::local());

    let outcome = matcher.rewrite(&buffered, &mut ());
    if !matches!(outcome, RewriteResult::NoMatch) {
        panic!("Should not remove BUFFERIZE from expensive op");
    }
}
/// A PERMUTE with no ranges registered in the `IndexingContext` must not be touched
/// by the rangeify patterns.
#[test]
fn test_movement_op_removal_no_match_without_ranges() {
    let matcher = patterns::apply_rangeify_patterns();
    let mut ctx = IndexingContext::new();
    let src = UOp::native_const(1.0f32);
    let permuted = UOp::new(Op::Permute { src: src.clone(), axes: vec![1, 0] }, DType::Float32);

    let outcome = matcher.rewrite(&permuted, &mut ctx);
    if !matches!(outcome, RewriteResult::NoMatch) {
        panic!("Should NOT remove movement op without ranges assigned");
    }
}
/// Once ranges are registered for a PERMUTE in the `IndexingContext`, the rangeify
/// patterns must replace it with its source operand (pointer-identical).
#[test]
fn test_movement_op_removal_removes_with_ranges() {
    let matcher = patterns::apply_rangeify_patterns();
    let mut ctx = IndexingContext::new();
    let src = UOp::native_const(1.0f32);
    let permuted = UOp::new(Op::Permute { src: src.clone(), axes: vec![1, 0] }, DType::Float32);

    let range_op = Op::Range {
        end: UOp::index_const(5),
        axis_id: AxisId::Renumbered(0),
        axis_type: AxisType::Loop,
        deps: smallvec::SmallVec::new(),
    };
    let range = UOp::new(range_op, DType::Index);
    ctx.set_ranges(&permuted, vec![range.clone()], vec![range.clone()]);

    match matcher.rewrite(&permuted, &mut ctx) {
        RewriteResult::Rewritten(replacement) => {
            assert!(Arc::ptr_eq(&replacement, &src), "Should return the source operand");
        }
        _ => panic!("Expected movement op to be removed when ranges are assigned"),
    }
}
/// Once ranges are registered for a RESHAPE in the `IndexingContext`, the rangeify
/// patterns must replace it with its source operand (pointer-identical).
#[test]
fn test_movement_op_removal_reshape() {
    let matcher = patterns::apply_rangeify_patterns();
    let mut ctx = IndexingContext::new();
    let src = UOp::native_const(1.0f32);
    let target_shape = UOp::vectorize(smallvec::smallvec![UOp::index_const(4), UOp::index_const(8)]);
    let reshaped = UOp::new(Op::Reshape { src: src.clone(), new_shape: target_shape }, DType::Float32);

    let range_op = Op::Range {
        end: UOp::index_const(4),
        axis_id: AxisId::Renumbered(0),
        axis_type: AxisType::Loop,
        deps: smallvec::SmallVec::new(),
    };
    let range = UOp::new(range_op, DType::Index);
    ctx.set_ranges(&reshaped, vec![range.clone()], vec![range.clone()]);

    match matcher.rewrite(&reshaped, &mut ctx) {
        RewriteResult::Rewritten(replacement) => {
            assert!(Arc::ptr_eq(&replacement, &src), "RESHAPE should be removed");
        }
        _ => panic!("Expected RESHAPE to be removed when ranges are assigned"),
    }
}
/// Once ranges are registered for an EXPAND in the `IndexingContext`, the rangeify
/// patterns must replace it with its source operand (pointer-identical).
#[test]
fn test_movement_op_removal_expand() {
    let matcher = patterns::apply_rangeify_patterns();
    let mut ctx = IndexingContext::new();
    let src = UOp::native_const(1.0f32);
    let target_shape = UOp::vectorize(smallvec::smallvec![UOp::index_const(4), UOp::index_const(8)]);
    let expanded = UOp::new(Op::Expand { src: src.clone(), new_shape: target_shape }, DType::Float32);

    let range_op = Op::Range {
        end: UOp::index_const(4),
        axis_id: AxisId::Renumbered(0),
        axis_type: AxisType::Loop,
        deps: smallvec::SmallVec::new(),
    };
    let range = UOp::new(range_op, DType::Index);
    ctx.set_ranges(&expanded, vec![range.clone()], vec![range.clone()]);

    match matcher.rewrite(&expanded, &mut ctx) {
        RewriteResult::Rewritten(replacement) => {
            assert!(Arc::ptr_eq(&replacement, &src), "EXPAND should be removed");
        }
        _ => panic!("Expected EXPAND to be removed when ranges are assigned"),
    }
}
/// A unary SQRT with no ranges registered is not a movement op and must be left
/// alone by the rangeify patterns.
#[test]
fn test_movement_op_removal_non_movement_op() {
    let matcher = patterns::apply_rangeify_patterns();
    let mut ctx = IndexingContext::new();
    let operand = UOp::native_const(1.0f32);
    let root = operand.try_sqrt().unwrap();

    let outcome = matcher.rewrite(&root, &mut ctx);
    if !matches!(outcome, RewriteResult::NoMatch) {
        panic!("Should not match non-movement ops without ranges");
    }
}
/// Chains two pattern sets: `early_rewrites` first strips DETACH, then the result is
/// wrapped in BUFFERIZE and fed to `buffer_folding`. If the second stage rewrites,
/// the outcome must be the original constant. A non-rewrite in stage two is tolerated.
#[test]
fn test_pattern_composition() {
    let x = UOp::const_(DType::Float32, ConstValue::Float(1.0));
    let detach = x.detach();

    // Stage 1: strip the DETACH.
    let early = patterns::early_rewrites();
    let result1 = early.rewrite(&detach, &mut ());
    assert!(matches!(result1, RewriteResult::Rewritten(_)));
    let unwrapped = match result1 {
        RewriteResult::Rewritten(node) => node,
        _ => panic!("Should have rewritten"),
    };

    // Stage 2: wrap in BUFFERIZE and fold.
    let axis = UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(0), AxisType::Loop);
    let buffered = UOp::bufferize(unwrapped, vec![axis], BufferizeOpts::local());
    let folding = patterns::buffer_folding();
    if let RewriteResult::Rewritten(folded) = folding.rewrite(&buffered, &mut ()) {
        assert!(Arc::ptr_eq(&folded, &x), "Should have removed both DETACH and BUFFERIZE");
    }
}
/// Running `early_rewrites` again on the output of a DETACH removal must report
/// `NoMatch`.
#[test]
fn test_idempotent_patterns() {
    let x = UOp::const_(DType::Float32, ConstValue::Float(1.0));
    let detach = x.detach();
    let matcher = patterns::early_rewrites();

    // First pass: must rewrite.
    let result1 = matcher.rewrite(&detach, &mut ());
    assert!(matches!(result1, RewriteResult::Rewritten(_)));
    let unwrapped = match result1 {
        RewriteResult::Rewritten(node) => node,
        _ => x.clone(),
    };

    // Second pass over the already-unwrapped node: nothing left to do.
    if !matches!(matcher.rewrite(&unwrapped, &mut ()), RewriteResult::NoMatch) {
        panic!("Should not match on already-processed node");
    }
}