use morok_dtype::DType;
use morok_ir::{AxisId, AxisType, Op, UOp};
use crate::rangeify::{
IndexingContext,
transforms::{transform_single_source, transform_sources_with_bufferize},
};
/// An ADD whose two operands are raw buffers, with a single loop range
/// assigned to the consumer: `transform_sources_with_bufferize` must return
/// a rewritten source list in which both operands are `INDEX` ops.
#[test]
fn test_transform_buffer_source() {
    let buffer1 = UOp::new_buffer(morok_device::DeviceSpec::Cpu, 40, DType::Float32);
    let buffer2 = UOp::new_buffer(morok_device::DeviceSpec::Cpu, 40, DType::Float32);
    let range = UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(0), AxisType::Loop);
    let consumer = buffer1.try_add(&buffer2).unwrap();

    let mut ctx = IndexingContext::new();
    // Both range lists passed to `set_ranges` are the same single loop axis;
    // `range` is not needed afterwards, so the second list takes ownership.
    ctx.set_ranges(&consumer, vec![range.clone()], vec![range]);

    let new_sources = transform_sources_with_bufferize(&consumer, &mut ctx)
        .expect("buffer sources of a ranged consumer should be rewritten");
    assert_eq!(new_sources.len(), 2, "expected one rewritten source per operand");
    for (i, src) in new_sources.iter().enumerate() {
        assert!(matches!(src.op(), Op::Index { .. }), "rewritten source {i} should be an INDEX op");
    }
}
/// A source marked for realization is rewritten by `transform_single_source`
/// into an `INDEX` whose buffer is a `BUFFERIZE` node.
#[test]
fn test_transform_realizable_source() {
    let lhs = UOp::native_const(1.0f32);
    let rhs = UOp::native_const(2.0f32);
    let x = lhs.try_add(&rhs).unwrap();
    let consumer = x.try_sqrt().unwrap();

    // Built directly through `Op::Range` (rather than a helper) with no deps.
    let range_op = Op::Range {
        end: UOp::index_const(5),
        axis_id: AxisId::Renumbered(0),
        axis_type: AxisType::Loop,
        deps: smallvec::SmallVec::new(),
    };
    let range = UOp::new(range_op, DType::Index);
    let ranges = std::slice::from_ref(&range);

    let mut ctx = IndexingContext::new();
    ctx.set_ranges(&x, ranges.to_vec(), ranges.to_vec());
    ctx.set_ranges(&consumer, ranges.to_vec(), ranges.to_vec());
    // Mark `x` for realization (axis list `[0]`).
    ctx.mark_realize(&x, vec![0]);

    let new_src = transform_single_source(&consumer, &x, ranges, &mut ctx);
    match new_src.op() {
        Op::Index { buffer, .. } => assert!(matches!(buffer.op(), Op::Bufferize { .. })),
        _ => panic!("Expected INDEX operation"),
    }
}
/// With a fresh context (no ranges assigned, nothing marked for realization),
/// an ADD of two constants has nothing to rewrite, so the transform yields `None`.
#[test]
fn test_no_transform_for_normal_source() {
    let lhs = UOp::native_const(1.0f32);
    let rhs = UOp::native_const(2.0f32);
    let sum = lhs.try_add(&rhs).unwrap();

    let mut ctx = IndexingContext::new();
    assert!(transform_sources_with_bufferize(&sum, &mut ctx).is_none());
}
/// An ADD whose operands are `RESHAPE(BUFFER)` chains: movement ops are not
/// rewritten by `transform_sources_with_bufferize` (they are left for the BPM
/// rewrite engine), so it must return `None` even with ranges assigned.
#[test]
fn test_transform_movement_chain_on_buffer() {
    // Builds RESHAPE(BUFFER of 12 f32) to shape (3, 4); invoked once per operand.
    let reshaped_buffer = || {
        let buffer = UOp::new_buffer(morok_device::DeviceSpec::Cpu, 12, DType::Float32);
        let new_shape = UOp::vectorize(vec![UOp::index_const(3), UOp::index_const(4)].into());
        UOp::new(Op::Reshape { src: buffer, new_shape }, DType::Float32)
    };

    let reshape = reshaped_buffer();
    assert!(reshape.op().is_movement(), "RESHAPE should be identified as movement op");
    let reshape2 = reshaped_buffer();
    let add = reshape.try_add(&reshape2).unwrap();

    let range0 = UOp::range_axis(UOp::index_const(3), AxisId::Renumbered(0), AxisType::Loop);
    let range1 = UOp::range_axis(UOp::index_const(4), AxisId::Renumbered(1), AxisType::Loop);
    let mut ctx = IndexingContext::new();
    // The ranges are not used again, so the second list takes ownership.
    ctx.set_ranges(&add, vec![range0.clone(), range1.clone()], vec![range0, range1]);

    let new_sources = transform_sources_with_bufferize(&add, &mut ctx);
    assert!(new_sources.is_none(), "Movement ops should be left for BPM rewrite engine");
}
/// End-to-end `rangeify` over `PERMUTE(RESHAPE(BUFFER of 6 f32))`, where the
/// reshape is to (2, 3) and the permute swaps the two axes.
///
/// NOTE(review): despite the name, this only checks that rangeify succeeds and
/// that the result keeps the `Float32` dtype; it asserts nothing about
/// symbolic simplification itself.
#[test]
fn test_rangeify_with_symbolic_simplification() {
    let src = UOp::new_buffer(morok_device::DeviceSpec::Cpu, 6, DType::Float32);
    let reshaped = src
        .try_reshape(&smallvec::smallvec![morok_ir::SInt::Const(2), morok_ir::SInt::Const(3)])
        .expect("reshaping a 6-element buffer to (2, 3) should succeed");
    let permuted = reshaped.try_permute(vec![1, 0]).expect("swapping the two axes should succeed");
    let (result, _ctx) = crate::rangeify::rangeify(permuted, None).expect("rangeify should succeed");
    // `assert_eq!` reports both dtypes on failure, unlike `assert!(a == b)`.
    assert_eq!(result.dtype(), DType::Float32);
}