use morok_dtype::DType;
use morok_ir::types::ConstValue;
use morok_ir::{Op, UOp};
use super::helpers::*;
/// A BUFFERIZE whose compute and single range are both iota UNROLLs (axis 0, size 4)
/// must have its UNROLL handled by `phase2_only`.
///
/// Accepted outcomes:
/// - the root is still a BUFFERIZE whose compute is no longer a raw UNROLL, or GEPs
///   appear in the rewritten graph;
/// - otherwise, the graph has no UNROLLs left, contains a CONTRACT, or contains a GEP.
#[test]
fn test_bufferize_unroll_basic() {
    let compute = create_unroll_iota(0, 4);
    let range = create_unroll_iota(0, 4);
    let bufferize = create_bufferize_global(compute, vec![range]);
    let result = phase2_only(&bufferize);

    // Both branches accept GEPs as evidence that the UNROLL was processed.
    let gep_count = count_ops(&result, |u| matches!(u.op(), Op::Gep { .. }));

    match result.op() {
        Op::Bufferize { compute: c, .. } => {
            // No CONTRACT check is needed here: this assertion can only fail when `c`
            // is an UNROLL, and then `c` cannot also be a CONTRACT.
            let has_raw_unroll = matches!(c.op(), Op::Unroll { .. });
            assert!(
                !has_raw_unroll || gep_count > 0,
                "UNROLL should be processed, got {:?}",
                c.op()
            );
        }
        _ => {
            let unroll_count = count_unrolls(&result);
            let contract_count = count_contracts(&result);
            assert!(
                unroll_count == 0 || contract_count > 0 || gep_count > 0,
                "UNROLL should be processed: {unroll_count} UNROLL(s), \
                 {contract_count} CONTRACT(s), {gep_count} GEP(s) in result"
            );
        }
    }
}
/// Bufferizes a two-axis UNROLL (axis 0 of size 2, axis 1 of size 3) against an
/// identical UNROLL range and runs `phase2_only`.
///
/// When the root is still a BUFFERIZE, its compute must report a non-zero vector
/// count. Any other root is not inspected.
#[test]
fn test_bufferize_unroll_multi_axis() {
    let bufferize = create_bufferize_global(
        create_unroll_multi_axis(vec![(0, 2), (1, 3)]),
        vec![create_unroll_multi_axis(vec![(0, 2), (1, 3)])],
    );
    let rewritten = phase2_only(&bufferize);

    if let Op::Bufferize { compute: contracted, .. } = rewritten.op() {
        let lanes = contracted.dtype().vcount();
        assert!(lanes >= 1, "Contracted compute should have valid vcount");
    }
}
/// Pairs a scalar float32 constant compute (1.0) with an iota UNROLL range (axis 0,
/// size 4).
///
/// After `phase2_only`, the graph must either still contain a BUFFERIZE or contain
/// no CONTRACT at all.
#[test]
fn test_bufferize_no_unroll_compute_passthrough() {
    let scalar = UOp::const_(DType::Float32, ConstValue::Float(1.0));
    let unroll_range = create_unroll_iota(0, 4);
    let input = create_bufferize_global(scalar, vec![unroll_range]);
    let rewritten = phase2_only(&input);

    let kept_bufferize = count_bufferizes(&rewritten) > 0;
    assert!(kept_bufferize || count_contracts(&rewritten) == 0);
}
/// Pairs an iota UNROLL compute (axis 0, size 4) with a plain index constant (0) as
/// its range.
///
/// After `phase2_only`, the graph must either still contain a BUFFERIZE or contain
/// at least one GEP.
#[test]
fn test_bufferize_no_unroll_ranges_passthrough() {
    let unrolled = create_unroll_iota(0, 4);
    let index_zero = UOp::const_(DType::Index, ConstValue::Int(0));
    let input = create_bufferize_global(unrolled, vec![index_zero]);
    let rewritten = phase2_only(&input);

    let survived = count_bufferizes(&rewritten) > 0;
    assert!(survived || count_ops(&rewritten, |node| matches!(node.op(), Op::Gep { .. })) > 0);
}
/// Builds a BUFFERIZE with two separate UNROLL ranges (axis 0 and axis 1, size 2 each)
/// over an iota UNROLL compute.
///
/// If the root is still a BUFFERIZE after `phase2_only`, both input ranges must still
/// be present.
#[test]
fn test_bufferize_multiple_ranges_passthrough() {
    let compute = create_unroll_iota(0, 4);
    let range1 = create_unroll_iota(0, 2);
    let range2 = create_unroll_iota(1, 2);
    let bufferize = create_bufferize_global(compute, vec![range1, range2]);
    let result = phase2_only(&bufferize);
    if let Op::Bufferize { ranges, .. } = result.op() {
        // A non-emptiness check would still pass if one range were dropped.
        // Passthrough means the range count is unchanged.
        assert_eq!(ranges.len(), 2, "both BUFFERIZE ranges should pass through");
    }
}
/// Uses UNROLLs with an empty axis list, each wrapping a one-element int64 vconst, as
/// both the compute and the single range.
///
/// A BUFFERIZE must still be present after `phase2_only`.
#[test]
fn test_bufferize_empty_unroll_axes_passthrough() {
    let base = UOp::vconst(vec![ConstValue::Int(0)], DType::Int64);
    let (compute, range) = (base.clone().unroll(vec![]), base.unroll(vec![]));
    let input = create_bufferize_global(compute, vec![range]);
    let rewritten = phase2_only(&input);

    let bufferize_count = count_bufferizes(&rewritten);
    assert!(bufferize_count > 0);
}
/// Bufferizes an iota UNROLL (axis 0, size 4) against a matching UNROLL range. If the
/// root is still a BUFFERIZE after `phase2_only`, its compute must have a non-zero
/// vector count.
///
/// NOTE(review): despite the name, no dtype equality is asserted here. Consider
/// pinning the expected vcount once the contract width is confirmed.
#[test]
fn test_bufferize_contract_dtype_matches() {
    let input = create_bufferize_global(create_unroll_iota(0, 4), vec![create_unroll_iota(0, 4)]);
    let rewritten = phase2_only(&input);

    if let Op::Bufferize { compute: contracted, .. } = rewritten.op() {
        let vcount = contracted.dtype().vcount();
        assert!(vcount >= 1);
    }
}
/// Checks that `BufferizeOpts` built with `BufferizeOpts::local()` survive
/// `phase2_only` on a BUFFERIZE of matching iota UNROLLs (axis 0, size 4).
///
/// Two checks are made: the root, when it is still a BUFFERIZE, and every BUFFERIZE
/// counted in the rewritten graph. The graph-wide check keeps the options checked
/// even when the root has been rewritten into a different op.
#[test]
fn test_bufferize_preserves_opts() {
    use morok_dtype::AddrSpace;
    use morok_ir::BufferizeOpts;
    let compute = create_unroll_iota(0, 4);
    let range = create_unroll_iota(0, 4);
    let opts = BufferizeOpts::local();
    let bufferize = create_bufferize(compute, vec![range], opts);
    let result = phase2_only(&bufferize);
    if let Op::Bufferize { opts: result_opts, .. } = result.op() {
        assert_eq!(result_opts.addrspace, AddrSpace::Local, "BufferizeOpts should be preserved");
    }
    // The only BUFFERIZE built above uses local opts, so any non-local BUFFERIZE in the
    // result means the options were lost during rewriting.
    let non_local = count_ops(&result, |u| {
        matches!(u.op(), Op::Bufferize { opts: buf_opts, .. } if buf_opts.addrspace != AddrSpace::Local)
    });
    assert_eq!(non_local, 0, "BufferizeOpts should be preserved on every BUFFERIZE in the graph");
}
/// Runs the full `expander_rewrite` on a BUFFERIZE of matching iota UNROLLs (axis 0,
/// size 4).
///
/// Any UNROLL with a non-empty axis list left in the result must be accompanied by at
/// least one CONTRACT or GEP in the graph.
#[test]
fn test_bufferize_unroll_full_expander() {
    let input = create_bufferize_global(create_unroll_iota(0, 4), vec![create_unroll_iota(0, 4)]);
    let expanded = expander_rewrite(&input);

    let leftover_unrolls = count_ops(&expanded, |node| {
        matches!(node.op(), Op::Unroll { unroll_axes: axes, .. } if !axes.is_empty())
    });
    let fully_expanded = leftover_unrolls == 0;

    assert!(
        fully_expanded
            || count_contracts(&expanded) > 0
            || count_ops(&expanded, |node| matches!(node.op(), Op::Gep { .. })) > 0,
        "UNROLLs should be expanded/contracted"
    );
}