use std::sync::Arc;
use morok_device::DeviceSpec;
use morok_dtype::DType;
use morok_ir::{AxisId, AxisType, Op, SInt, UOp};
use crate::rangeify::patterns::movement_op_patterns;
use crate::rewrite::graph_rewrite;
/// Test helper: builds a buffer `UOp` with the CPU device spec and `Float32` dtype,
/// forwarding `size` unchanged to `UOp::new_buffer`.
fn create_buffer(size: usize) -> Arc<UOp> {
    let device = DeviceSpec::Cpu;
    let dtype = DType::Float32;
    UOp::new_buffer(device, size, dtype)
}
/// Test helper: builds a `Loop`-typed range `UOp` on axis `AxisId::Renumbered(axis_id)`
/// whose end operand is the index constant `size`.
fn create_range(size: usize, axis_id: usize) -> Arc<UOp> {
    let axis = AxisId::Renumbered(axis_id);
    UOp::range_axis(UOp::index_const(size as i64), axis, AxisType::Loop)
}
/// EXPAND case: a `[10, 1, 20]` view expanded to `[10, 5, 20]` and indexed with three
/// ranges must rewrite to an INDEX that carries a single index and whose buffer operand
/// is a `Buffer` op.
#[test]
fn test_expand_index_transformation() {
    // The `* 1` is deliberate: it mirrors the [10, 1, 20] shape used in the reshape below.
    #[allow(clippy::identity_op)]
    let buffer = create_buffer(10 * 1 * 20);
    let reshaped =
        buffer.try_reshape(&vec![SInt::Const(10), SInt::Const(1), SInt::Const(20)].into_iter().collect()).unwrap();
    let shape2 = UOp::vectorize(vec![UOp::index_const(10), UOp::index_const(5), UOp::index_const(20)].into());
    let expanded = UOp::new(Op::Expand { src: reshaped, new_shape: shape2 }, DType::Float32);

    // One range per axis of the expanded shape.
    let r0 = create_range(10, 0);
    let r1 = create_range(5, 1);
    let r2 = create_range(20, 2);
    // The ranges are not needed afterwards, so they are moved rather than cloned.
    let indexed = UOp::index().buffer(expanded).indices(vec![r0, r1, r2]).call().unwrap();

    let pm = movement_op_patterns();
    let result = graph_rewrite(&pm, indexed, &mut ());

    assert!(matches!(result.op(), Op::Index { .. }), "Result should be INDEX");
    let Op::Index { buffer: res_buf, indices: res_idx, .. } = result.op() else {
        panic!("Expected INDEX");
    };
    assert_eq!(res_idx.len(), 1, "Should have 1 index after all movement ops transformed");
    assert!(matches!(res_buf.op(), Op::Buffer { .. }), "Buffer should be the original buffer");
}
/// PERMUTE case: a `[10, 20, 30]` view permuted with `[1, 2, 0]` and indexed with three
/// ranges must rewrite to an INDEX that carries a single index and whose buffer operand
/// is a `Buffer` op.
#[test]
fn test_permute_index_transformation() {
    let buffer = create_buffer(10 * 20 * 30);
    let reshaped =
        buffer.try_reshape(&vec![SInt::Const(10), SInt::Const(20), SInt::Const(30)].into_iter().collect()).unwrap();
    let permuted = reshaped.try_permute(vec![1, 2, 0]).unwrap();

    // Range sizes follow the permuted axis order: 20, 30, 10.
    let r0 = create_range(20, 0);
    let r1 = create_range(30, 1);
    let r2 = create_range(10, 2);
    // The ranges are not needed afterwards, so they are moved rather than cloned.
    let indexed = UOp::index().buffer(permuted).indices(vec![r0, r1, r2]).call().unwrap();

    let pm = movement_op_patterns();
    let result = graph_rewrite(&pm, indexed, &mut ());

    assert!(matches!(result.op(), Op::Index { .. }), "Result should be INDEX");
    let Op::Index { buffer: res_buf, indices: res_idx, .. } = result.op() else {
        panic!("Expected INDEX");
    };
    assert_eq!(res_idx.len(), 1, "Should have 1 index after all movement ops transformed");
    assert!(matches!(res_buf.op(), Op::Buffer { .. }), "Buffer should be the original buffer");
}
/// RESHAPE case: a two-range index into a `[10, 20]` view of a buffer created with
/// size 200 is expected to come back from the rewrite as an INDEX with one index.
#[test]
fn test_reshape_index_transformation() {
    let view = create_buffer(200)
        .try_reshape(&vec![SInt::Const(10), SInt::Const(20)].into_iter().collect())
        .unwrap();

    let ranges = vec![create_range(10, 0), create_range(20, 1)];
    let index_node = UOp::index().buffer(view).indices(ranges).call().unwrap();

    let patterns = movement_op_patterns();
    let rewritten = graph_rewrite(&patterns, index_node, &mut ());

    assert!(matches!(rewritten.op(), Op::Index { .. }));
    match rewritten.op() {
        Op::Index { indices, .. } => {
            assert_eq!(indices.len(), 1, "Should flatten to 1D index");
        }
        _ => panic!("Expected INDEX"),
    }
}
/// SHRINK case: a `[10, 40]` view shrunk with begins `(0, 10)` and ends `(5, 30)` is
/// indexed with ranges of 5 and 20. Only checks that the rewritten root is an INDEX.
#[test]
fn test_shrink_index_transformation() {
    // Packs a pair of index constants into one vectorized operand.
    let bounds = |first: i64, second: i64| {
        UOp::vectorize(vec![UOp::index_const(first), UOp::index_const(second)].into())
    };

    let view = create_buffer(10 * 40)
        .try_reshape(&vec![SInt::Const(10), SInt::Const(40)].into_iter().collect())
        .unwrap();
    let begins = bounds(0, 10);
    let ends = bounds(5, 30);
    let shrunk = UOp::new(Op::Shrink { src: view, begins, ends }, DType::Float32);

    let ranges = vec![create_range(5, 0), create_range(20, 1)];
    let index_node = UOp::index().buffer(shrunk).indices(ranges).call().unwrap();

    let patterns = movement_op_patterns();
    let rewritten = graph_rewrite(&patterns, index_node, &mut ());

    assert!(matches!(rewritten.op(), Op::Index { .. }));
}
/// FLIP case: a `[10, 20]` view with `axes = [false, true]` is indexed with two ranges.
/// Only checks that the rewritten root is an INDEX.
#[test]
fn test_flip_index_transformation() {
    let view = create_buffer(10 * 20)
        .try_reshape(&vec![SInt::Const(10), SInt::Const(20)].into_iter().collect())
        .unwrap();
    let flip_axes = vec![false, true];
    let flipped = UOp::new(Op::Flip { src: view, axes: flip_axes }, DType::Float32);

    let ranges = vec![create_range(10, 0), create_range(20, 1)];
    let index_node = UOp::index().buffer(flipped).indices(ranges).call().unwrap();

    let patterns = movement_op_patterns();
    let rewritten = graph_rewrite(&patterns, index_node, &mut ());

    assert!(matches!(rewritten.op(), Op::Index { .. }));
}
/// PAD case: a `[10, 20]` view padded by `(1, 2)` at both the beginning and the end is
/// indexed with ranges of 12 and 24. Only checks that the rewritten root is an INDEX.
#[test]
fn test_pad_index_transformation() {
    // Both pad operands use the same `(1, 2)` pair; each call builds its own operand.
    let pad_pair = || UOp::vectorize(vec![UOp::index_const(1), UOp::index_const(2)].into());

    let view = create_buffer(10 * 20)
        .try_reshape(&vec![SInt::Const(10), SInt::Const(20)].into_iter().collect())
        .unwrap();
    let begin_pads = pad_pair();
    let end_pads = pad_pair();
    let padded = UOp::new(Op::Pad { src: view, begin_pads, end_pads }, DType::Float32);

    let ranges = vec![create_range(12, 0), create_range(24, 1)];
    let index_node = UOp::index().buffer(padded).indices(ranges).call().unwrap();

    let patterns = movement_op_patterns();
    let rewritten = graph_rewrite(&patterns, index_node, &mut ());

    assert!(matches!(rewritten.op(), Op::Index { .. }));
}
/// Negative case: indexing the result of `try_sqrt` must leave the INDEX in place with
/// a unary op still sitting in its buffer slot after the rewrite.
#[test]
fn test_non_movement_op_no_match() {
    let sqrt_op = create_buffer(100).try_sqrt().unwrap();

    let ranges = vec![create_range(100, 0)];
    let index_node = UOp::index().buffer(sqrt_op).indices(ranges).call().unwrap();

    let patterns = movement_op_patterns();
    let rewritten = graph_rewrite(&patterns, index_node, &mut ());

    assert!(matches!(rewritten.op(), Op::Index { .. }));
    let inner = match rewritten.op() {
        Op::Index { buffer, .. } => buffer,
        _ => panic!("Expected INDEX"),
    };
    assert!(matches!(inner.op(), Op::Unary(..)), "Buffer should still be the SQRT");
}
/// Nested case: reshape to `[10, 1]`, expand to `[10, 5]`, reshape to `[50]`, then index
/// with one range of 50. Only checks that the rewritten root is an INDEX.
#[test]
fn test_nested_movement_ops() {
    // The `* 1` is kept on purpose: it mirrors the [10, 1] shape of the first reshape.
    #[allow(clippy::identity_op)]
    let base = create_buffer(10 * 1);
    let column = base.try_reshape(&vec![SInt::Const(10), SInt::Const(1)].into_iter().collect()).unwrap();

    let target_shape = UOp::vectorize(vec![UOp::index_const(10), UOp::index_const(5)].into());
    let widened = UOp::new(Op::Expand { src: column, new_shape: target_shape }, DType::Float32);
    let flat = widened.try_reshape(&vec![SInt::Const(50)].into_iter().collect()).unwrap();

    let ranges = vec![create_range(50, 0)];
    let index_node = UOp::index().buffer(flat).indices(ranges).call().unwrap();

    let patterns = movement_op_patterns();
    let rewritten = graph_rewrite(&patterns, index_node, &mut ());

    assert!(matches!(rewritten.op(), Op::Index { .. }));
}