use std::sync::Arc;
use morok_ir::{AxisId, AxisType, DType, Op, UOp};
use crate::passes::pm_linearize_multi_index;
/// Builds a `Loop`-typed range UOp whose end bound is the index constant `size`,
/// tagged with the renumbered axis id `axis_id`.
fn make_range(size: i64, axis_id: usize) -> Arc<UOp> {
    UOp::range_axis(UOp::index_const(size), AxisId::Renumbered(axis_id), AxisType::Loop)
}
/// Wraps a constant `0.0` (Float32) compute in a global bufferize.
///
/// The bufferize gets one `Loop` range per entry of `dims`. Range `n` has
/// extent `dims[n]` and renumbered axis id `n`.
fn make_bufferize(dims: &[i64]) -> Arc<UOp> {
    let zero = UOp::const_(DType::Float32, morok_ir::ConstValue::Float(0.0));
    let mut axis_ranges: Vec<Arc<UOp>> = Vec::with_capacity(dims.len());
    for (axis, &extent) in dims.iter().enumerate() {
        axis_ranges.push(make_range(extent, axis));
    }
    UOp::bufferize_global(zero, axis_ranges)
}
/// A 2-D `INDEX` into a 4x8 bufferized buffer must be rewritten into an
/// `INDEX` carrying a single (linear) index.
#[test]
fn test_linearize_pattern_2d() {
    let buf = make_bufferize(&[4, 8]);
    let row = make_range(4, 0);
    let col = make_range(8, 1);
    let indexed = UOp::index().buffer(buf).indices(vec![row, col]).call().unwrap();
    // Sources are the buffer followed by both indices.
    assert_eq!(indexed.op().sources().len(), 3);

    let rewritten = crate::rewrite::graph_rewrite(pm_linearize_multi_index(), indexed, &mut ());
    match rewritten.op() {
        Op::Index { indices, .. } => {
            assert_eq!(indices.len(), 1, "Should have single linear index after linearization")
        }
        _ => panic!("Expected INDEX op after linearization"),
    }
}
/// A 3-D `INDEX` into a 2x3x4 bufferized buffer must be rewritten into an
/// `INDEX` carrying a single (linear) index.
#[test]
fn test_linearize_pattern_3d() {
    let buffer = make_bufferize(&[2, 3, 4]);
    let i = make_range(2, 0);
    let j = make_range(3, 1);
    let k = make_range(4, 2);
    let multi_index = UOp::index().buffer(buffer).indices(vec![i, j, k]).call().unwrap();
    // Precondition: sources are the buffer plus three indices. Without this
    // check, the post-rewrite assertion could pass vacuously.
    assert_eq!(multi_index.op().sources().len(), 4);

    let result = crate::rewrite::graph_rewrite(pm_linearize_multi_index(), multi_index, &mut ());
    if let Op::Index { indices, .. } = result.op() {
        assert_eq!(indices.len(), 1, "3D index should be linearized to 1D");
    } else {
        panic!("Expected INDEX op");
    }
}
/// An `INDEX` that already has one index must come back from the rewrite as
/// the very same node (pointer-equal), not a rebuilt copy.
#[test]
fn test_single_index_unchanged() {
    let buf = make_bufferize(&[10]);
    let idx = UOp::index().buffer(buf).indices(vec![make_range(10, 0)]).call().unwrap();
    let rewritten =
        crate::rewrite::graph_rewrite(pm_linearize_multi_index(), Arc::clone(&idx), &mut ());
    assert!(Arc::ptr_eq(&rewritten, &idx), "Single index should not be transformed");
}
/// A 4-D `INDEX` into a 2x3x4x5 bufferized buffer must be rewritten into an
/// `INDEX` carrying a single (linear) index.
#[test]
fn test_linearize_pattern_4d() {
    let buffer = make_bufferize(&[2, 3, 4, 5]);
    let i = make_range(2, 0);
    let j = make_range(3, 1);
    let k = make_range(4, 2);
    let l = make_range(5, 3);
    let multi_index = UOp::index().buffer(buffer).indices(vec![i, j, k, l]).call().unwrap();
    // Precondition: sources are the buffer plus four indices. Without this
    // check, the post-rewrite assertion could pass vacuously.
    assert_eq!(multi_index.op().sources().len(), 5);

    let result = crate::rewrite::graph_rewrite(pm_linearize_multi_index(), multi_index, &mut ());
    if let Op::Index { indices, .. } = result.op() {
        assert_eq!(indices.len(), 1, "4D index should be linearized to 1D");
    } else {
        panic!("Expected INDEX op");
    }
}
/// A 2-D `INDEX` whose buffer is a `param` must still be linearized to a
/// single index. The param's pointer dtype is built with `ptr(None, Global)`.
#[test]
fn test_unbounded_buffer_still_linearizes() {
    let ptr_dtype = DType::Float32.ptr(None, morok_dtype::AddrSpace::Global);
    let buffer = UOp::param(0, 1024, ptr_dtype, None);
    let i = make_range(4, 0);
    let j = make_range(8, 1);
    let multi_index = UOp::index().buffer(buffer).indices(vec![i, j]).call().unwrap();
    // Precondition: sources are the buffer plus two indices. Without this
    // check, the post-rewrite assertion could pass vacuously.
    assert_eq!(multi_index.op().sources().len(), 3);

    let result = crate::rewrite::graph_rewrite(pm_linearize_multi_index(), multi_index, &mut ());
    if let Op::Index { indices, .. } = result.op() {
        assert_eq!(indices.len(), 1, "Should have single linear index after linearization");
    } else {
        panic!("Expected INDEX op after linearization");
    }
}