use std::{f32::consts::PI, sync::Arc};
use crate::rangeify::patterns::buffer_folding;
use crate::rewrite::graph_rewrite;
use morok_dtype::DType;
use morok_ir::{ConstValue, UOp};
fn create_const(val: i64) -> Arc<UOp> {
UOp::native_const(val as i32)
}
fn create_range(end: i64, axis_id: usize) -> Arc<UOp> {
UOp::range_const(end, axis_id)
}
fn create_bufferize(compute: Arc<UOp>, ranges: Vec<Arc<UOp>>) -> Arc<UOp> {
UOp::bufferize_global(compute, ranges)
}
#[test]
fn test_noop_bufferize_same_ranges() {
let x = UOp::param(1, 10, DType::Float32, None);
let range = create_range(10, 0);
let ranges = vec![range.clone()];
let bufferized = create_bufferize(x.clone(), ranges.clone());
let indexed = UOp::index().buffer(bufferized).indices(ranges).call().unwrap();
let matcher = buffer_folding();
let result = graph_rewrite(&matcher, indexed, &mut ());
assert!(Arc::ptr_eq(&result, &x), "Noop BUFFERIZE should be removed");
}
#[test]
fn test_noop_bufferize_different_ranges() {
let x = UOp::param(1, 1024, DType::Float32, None);
let range1 = create_range(10, 0);
let range2 = create_range(10, 1);
let bufferized = create_bufferize(x.clone(), vec![range1]);
let indexed = UOp::index().buffer(bufferized.clone()).indices(vec![range2]).call().unwrap();
let matcher = buffer_folding();
let result = graph_rewrite(&matcher, indexed.clone(), &mut ());
assert!(Arc::ptr_eq(&result, &indexed), "Should not fold with different ranges");
}
#[test]
fn test_noop_bufferize_multiple_ranges() {
let x = UOp::param(1, 200, DType::Float32, None);
let range1 = create_range(10, 0);
let range2 = create_range(20, 1);
let ranges = vec![range1.clone(), range2.clone()];
let bufferized = create_bufferize(x.clone(), ranges.clone());
let indexed = UOp::index().buffer(bufferized).indices(ranges).call().unwrap();
let matcher = buffer_folding();
let result = graph_rewrite(&matcher, indexed, &mut ());
assert!(
!matches!(result.op(), morok_ir::Op::Bufferize { .. }),
"Noop BUFFERIZE with multiple ranges should be removed"
);
}
#[test]
fn test_bufferize_const_folding() {
let const_val = create_const(42);
let range = create_range(10, 0);
let bufferized = create_bufferize(const_val.clone(), vec![range]);
let matcher = buffer_folding();
let result = graph_rewrite(&matcher, bufferized, &mut ());
assert!(Arc::ptr_eq(&result, &const_val), "BUFFERIZE(CONST) should fold to CONST");
}
#[test]
fn test_bufferize_different_const_types() {
let test_cases = vec![
(DType::Int32, ConstValue::Int(100)),
(DType::Float32, ConstValue::Float(PI.into())),
(DType::Bool, ConstValue::Bool(true)),
];
for (dtype, val) in test_cases {
let const_val = UOp::const_(dtype.clone(), val);
let range = create_range(5, 0);
let bufferized = create_bufferize(const_val.clone(), vec![range]);
let matcher = buffer_folding();
let result = graph_rewrite(&matcher, bufferized, &mut ());
assert!(Arc::ptr_eq(&result, &const_val), "BUFFERIZE(CONST) should fold for {:?}", dtype);
}
}
#[test]
fn test_index_const_folding() {
let const_val = create_const(7);
let range = create_range(10, 0);
let indexed = UOp::index().buffer(const_val.clone()).indices(vec![range]).call().unwrap();
let matcher = buffer_folding();
let result = graph_rewrite(&matcher, indexed, &mut ());
assert!(Arc::ptr_eq(&result, &const_val), "INDEX(CONST) should fold to CONST");
}
#[test]
fn test_index_const_multiple_indices() {
let const_val = UOp::native_const(2.5f32);
let ranges = vec![create_range(10, 0), create_range(20, 1), create_range(30, 2)];
let indexed = UOp::index().buffer(const_val.clone()).indices(ranges).call().unwrap();
let matcher = buffer_folding();
let result = graph_rewrite(&matcher, indexed, &mut ());
assert!(Arc::ptr_eq(&result, &const_val), "INDEX(CONST) with multiple indices should fold");
}
#[test]
fn test_copy_const_folding() {
let const_val = create_const(99);
let device = UOp::device(morok_device::DeviceSpec::Cpu);
let copy = const_val.copy(device);
let matcher = buffer_folding();
let result = graph_rewrite(&matcher, copy, &mut ());
assert!(Arc::ptr_eq(&result, &const_val), "COPY(CONST) should fold to CONST");
}
#[test]
fn test_copy_const_different_devices() {
let const_val = UOp::native_const(1.5f32);
let devices = vec![morok_device::DeviceSpec::Cpu, morok_device::DeviceSpec::Cuda { device_id: 0 }];
for device_spec in devices {
let device = UOp::device(device_spec);
let copy = const_val.copy(device);
let matcher = buffer_folding();
let result = graph_rewrite(&matcher, copy, &mut ());
assert!(Arc::ptr_eq(&result, &const_val), "COPY(CONST) should fold regardless of device");
}
}
#[test]
fn test_nested_constant_folding() {
let const_val = create_const(123);
let range = create_range(15, 0);
let bufferized = create_bufferize(const_val.clone(), vec![range.clone()]);
let indexed = UOp::index().buffer(bufferized).indices(vec![range]).call().unwrap();
let matcher = buffer_folding();
let result = graph_rewrite(&matcher, indexed, &mut ());
assert!(Arc::ptr_eq(&result, &const_val), "Nested constant operations should fold completely");
}
#[test]
fn test_noop_fold_non_const_operations() {
let x = UOp::var("x", DType::Float32, 0, 100);
let y = UOp::var("y", DType::Float32, 0, 100);
let add = x.try_add(&y).unwrap();
let range = create_range(10, 0);
let bufferized = create_bufferize(add.clone(), vec![range.clone()]);
let indexed = UOp::index().buffer(bufferized.clone()).indices(vec![range]).call().unwrap();
let matcher = buffer_folding();
let result = graph_rewrite(&matcher, indexed.clone(), &mut ());
assert!(Arc::ptr_eq(&result, &add), "Noop BUFFERIZE+INDEX should fold regardless of operation type");
}