use std::sync::Arc;
use morok_dtype::DType;
use morok_ir::{AddrSpace, AxisId, AxisType, BufferizeOpts, Op, SInt, UOp};
use test_case::test_case;
use crate::rangeify::indexing::IndexingContext;
use crate::rangeify::kernel::PcontigConfig;
use crate::rangeify::patterns::buffer_removal_with_pcontig;
use crate::rewrite::graph_rewrite;
/// Builds a CPU-device `Param` node typed as a global-address-space pointer.
///
/// `id` becomes the param slot and `size` is used both as the param size and as
/// the pointer's element count.
fn create_test_buffer(size: usize, dtype: DType, id: usize) -> Arc<UOp> {
    let cpu = UOp::device(morok_device::DeviceSpec::Cpu);
    let param = Op::Param { slot: id, size, device: Some(cpu) };
    let ptr_dtype = dtype.ptr(Some(size), AddrSpace::Global);
    UOp::new(param, ptr_dtype)
}
/// Builds `INDEX(BUFFERIZE(src, buf_ranges), idx_ranges)`.
///
/// # Panics
/// Panics if the INDEX node cannot be constructed.
fn create_index_bufferize(
    src: Arc<UOp>,
    buf_ranges: Vec<Arc<UOp>>,
    idx_ranges: Vec<Arc<UOp>>,
    opts: BufferizeOpts,
) -> Arc<UOp> {
    UOp::index()
        .buffer(UOp::bufferize(src, buf_ranges, opts))
        .indices(idx_ranges)
        .call()
        .expect("Failed to create INDEX")
}
/// Builds `INDEX(buffer, [r1, r2]) + 1.0`.
///
/// The buffer is a 100-element f32 param in slot 1, and `r1` and `r2` are fresh
/// 10-wide loop ranges. Returns `(buffer, r1, r2, compute)`.
fn create_simple_graph(ctx: &mut IndexingContext) -> (Arc<UOp>, Arc<UOp>, Arc<UOp>, Arc<UOp>) {
    let buffer = create_test_buffer(100, DType::Float32, 1);
    let r1 = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let r2 = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let loaded = UOp::index()
        .buffer(buffer.clone())
        .indices(vec![r1.clone(), r2.clone()])
        .call()
        .expect("Failed to create INDEX");
    let plus_one = loaded.try_add(&UOp::native_const(1.0f32)).expect("Failed to create ADD");
    (buffer, r1, r2, plus_one)
}
/// Builds the sum of `num_buffers` INDEX loads.
///
/// Each load reads a distinct 100-element f32 buffer (param slots `0..num_buffers`)
/// through the same two 10-wide loop ranges. Returns `(buffers, ranges, compute)`.
///
/// # Panics
/// Panics if `num_buffers` is zero.
fn create_multi_buffer_graph(
    ctx: &mut IndexingContext,
    num_buffers: usize,
) -> (Vec<Arc<UOp>>, Vec<Arc<UOp>>, Arc<UOp>) {
    assert!(num_buffers > 0, "Must have at least one buffer");
    let ranges = vec![
        ctx.new_range(&SInt::Const(10), AxisType::Loop),
        ctx.new_range(&SInt::Const(10), AxisType::Loop),
    ];
    let buffers: Vec<Arc<UOp>> =
        (0..num_buffers).map(|slot| create_test_buffer(100, DType::Float32, slot)).collect();

    // Every buffer is indexed with the same ranges.
    let load = |buffer: &Arc<UOp>| {
        UOp::index().buffer(buffer.clone()).indices(ranges.clone()).call().expect("Failed to create INDEX")
    };

    let mut sum = load(&buffers[0]);
    for buffer in &buffers[1..] {
        sum = sum.try_add(&load(buffer)).expect("Failed to create ADD");
    }
    (buffers, ranges, sum)
}
/// Alias for `create_test_buffer`. The ratio tests use this name so that the size
/// argument is clearly the quantity under test.
fn create_buffer_with_size(size: usize, dtype: DType, id: usize) -> Arc<UOp> {
    create_test_buffer(size, dtype, id)
}
/// Builds `INDEX(input, [r, r'])` for the out/in ratio tests.
///
/// The input is an f32 buffer of `input_size` in slot 1. Each of the two loop ranges
/// has `floor(sqrt(output_size / 4))` iterations. The `/ 4` presumably converts an
/// f32 byte count to elements; NOTE(review): confirm. Returns
/// `(input_buffer, output_size, ranges, compute)`.
fn create_ratio_test_graph(
    ctx: &mut IndexingContext,
    input_size: usize,
    output_size: usize,
) -> (Arc<UOp>, usize, Vec<Arc<UOp>>, Arc<UOp>) {
    let input_buffer = create_buffer_with_size(input_size, DType::Float32, 1);
    let side = ((output_size / 4) as f64).sqrt() as usize;
    let ranges: Vec<Arc<UOp>> = (0..2).map(|_| ctx.new_range(&SInt::Const(side), AxisType::Loop)).collect();
    let load = UOp::index()
        .buffer(input_buffer.clone())
        .indices(ranges.clone())
        .call()
        .expect("Failed to create INDEX");
    (input_buffer, output_size, ranges, load)
}
/// Builds an ADD-reduce over a fresh 20-wide reduce range.
///
/// The reduced source is `INDEX(buffer, [loop, reduce])` over an 800-element f32
/// buffer when `has_buffer_access` is true, and the constant `1.0` otherwise.
/// Returns `([loop_range, reduce_range], reduce)`; the loop range is 10 wide.
fn create_reduce_graph(ctx: &mut IndexingContext, has_buffer_access: bool) -> (Vec<Arc<UOp>>, Arc<UOp>) {
    let loop_range = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let reduce_range = ctx.new_range(&SInt::Const(20), AxisType::Reduce);
    let all_ranges = vec![loop_range, reduce_range.clone()];

    let reduced_src = if has_buffer_access {
        let buffer = create_test_buffer(800, DType::Float32, 1);
        UOp::index().buffer(buffer).indices(all_ranges.clone()).call().expect("Failed to create INDEX")
    } else {
        UOp::native_const(1.0f32)
    };

    let reduce = UOp::new(
        Op::Reduce { src: reduced_src, ranges: vec![reduce_range].into(), reduce_op: morok_ir::ReduceOp::Add },
        DType::Float32,
    );
    (all_ranges, reduce)
}
/// Pins the default `PcontigConfig` values: level 2, max buffers 3, and an
/// out/in ratio threshold of 10.0.
#[test]
fn test_config_default() {
    let PcontigConfig { level, max_buffers_threshold, out_in_ratio_threshold } = PcontigConfig::default();
    assert_eq!(level, 2);
    assert_eq!(max_buffers_threshold, 3);
    assert_eq!(out_in_ratio_threshold, 10.0);
}
/// Checks that struct-update syntax stores each level from 0 through 3 unchanged.
#[test]
fn test_config_levels() {
    for expected in 0..=3 {
        let config = PcontigConfig { level: expected, ..Default::default() };
        assert_eq!(config.level, expected);
    }
}
/// Smoke test: building the pcontig buffer-removal matcher must not panic.
#[test]
fn test_pattern_matcher_creation() {
    let _matcher = buffer_removal_with_pcontig();
}
/// With `level: 0`, rewriting `INDEX(BUFFERIZE(...))` must return the input node itself.
#[test]
fn test_disabled_config_no_rewrite() {
    let matcher = buffer_removal_with_pcontig();
    let mut config = PcontigConfig { level: 0, ..Default::default() };
    let mut ctx = IndexingContext::new();
    let (_buffer, r1, r2, compute) = create_simple_graph(&mut ctx);
    let ranges = vec![r1, r2];
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let original = create_index_bufferize(compute, ranges.clone(), ranges, opts);
    let result = graph_rewrite(&matcher, original.clone(), &mut config);
    assert!(Arc::ptr_eq(&result, &original), "Expected no rewrite with level=0");
}
/// A bare BUFFERIZE of a constant should be replaced by the constant.
#[test]
fn test_cheap_inline_removal() {
    let matcher = buffer_removal_with_pcontig();
    let mut config = PcontigConfig::default();
    let mut ctx = IndexingContext::new();
    let range = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let bufferized = UOp::bufferize(UOp::native_const(1.0f32), vec![range], opts);
    let result = graph_rewrite(&matcher, bufferized, &mut config);
    assert!(matches!(result.op(), Op::Const(_)), "Expected const to be inlined");
}
/// Two stacked BUFFERIZEs around a constant collapse to the constant after two rewrite passes.
#[test]
fn test_nested_bufferize_removal() {
    let matcher = buffer_removal_with_pcontig();
    let mut config = PcontigConfig::default();
    let mut ctx = IndexingContext::new();
    let range = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let inner = UOp::bufferize(UOp::native_const(1.0f32), vec![range.clone()], opts.clone());
    let outer = UOp::bufferize(inner, vec![range], opts);
    let final_result = (0..2).fold(outer, |uop, _| graph_rewrite(&matcher, uop, &mut config));
    assert!(matches!(final_result.op(), Op::Const(_)), "Expected const to be fully inlined");
}
/// `INDEX(BUFFERIZE(INDEX(buf) + 1))` over one input buffer with the default config.
///
/// The graph has no reduce and only one accessed buffer, so it takes the simple
/// removal path. Like `test_out_in_ratio_efficient_buffer`, it should be fully
/// inlined, leaving no BUFFERIZE in the result.
#[test]
fn test_simple_index_bufferize_pattern() {
    let mut config = PcontigConfig::default();
    let matcher = buffer_removal_with_pcontig();
    let mut ctx = IndexingContext::new();
    let (_buffer, range1, range2, compute) = create_simple_graph(&mut ctx);
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let idx_buf = create_index_bufferize(compute, vec![range1.clone(), range2.clone()], vec![range1, range2], opts);
    let rewritten = graph_rewrite(&matcher, idx_buf, &mut config);
    assert_eq!(count_bufferizes(&rewritten), 0, "BUFFERIZE should be removed after inlining");
}
/// Sums `num_buffers` loads and wraps them in `INDEX(BUFFERIZE(...))`. Above
/// `max_buffers_threshold`, the node must be left untouched; at or below it, the
/// case only checks that rewriting completes.
#[test_case(1 ; "one buffer - should optimize")]
#[test_case(2 ; "two buffers - should optimize")]
#[test_case(3 ; "three buffers - at threshold should optimize")]
#[test_case(4 ; "four buffers - above threshold should keep")]
#[test_case(5 ; "five buffers - above threshold should keep")]
fn test_accessed_buffers_threshold(num_buffers: usize) {
    // NOTE(review): this runs with level 0, and `test_disabled_config_no_rewrite`
    // shows level 0 leaves INDEX(BUFFERIZE) untouched. The above-threshold
    // assertion may therefore hold regardless of `max_buffers_threshold`; confirm intent.
    let matcher = buffer_removal_with_pcontig();
    let mut config = PcontigConfig { level: 0, ..PcontigConfig::default() };
    let mut ctx = IndexingContext::new();
    let (_buffers, ranges, compute) = create_multi_buffer_graph(&mut ctx, num_buffers);
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let idx_buf = create_index_bufferize(compute, ranges.clone(), ranges, opts);
    let rewritten = graph_rewrite(&matcher, idx_buf.clone(), &mut config);

    let threshold = config.max_buffers_threshold;
    if num_buffers <= threshold {
        return;
    }
    assert!(
        Arc::ptr_eq(&rewritten, &idx_buf),
        "Expected no rewrite with {} buffers (>{} threshold)",
        num_buffers,
        threshold
    );
}
/// Sums three INDEX loads of the same buffer and rewrites the result.
///
/// This only checks that the rewrite completes; it makes no assertion on the result.
#[test]
fn test_accessed_buffers_with_duplicates() {
    let matcher = buffer_removal_with_pcontig();
    let mut config = PcontigConfig::default();
    let mut ctx = IndexingContext::new();
    let buffer = create_test_buffer(100, DType::Float32, 1);
    let ranges = vec![
        ctx.new_range(&SInt::Const(10), AxisType::Loop),
        ctx.new_range(&SInt::Const(10), AxisType::Loop),
    ];
    let load = || UOp::index().buffer(buffer.clone()).indices(ranges.clone()).call().expect("Failed to create INDEX");
    let compute = load()
        .try_add(&load())
        .expect("Failed to create ADD")
        .try_add(&load())
        .expect("Failed to create ADD");
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let idx_buf = create_index_bufferize(compute, ranges.clone(), ranges, opts);
    let _rewritten = graph_rewrite(&matcher, idx_buf, &mut config);
}
/// Builds `(b1 + b2) * (b3 + b4)` over four distinct buffers and checks that
/// rewriting with level 0 leaves the `INDEX(BUFFERIZE(...))` node untouched.
#[test]
fn test_accessed_buffers_nested_computation() {
    let matcher = buffer_removal_with_pcontig();
    let mut config = PcontigConfig { level: 0, ..PcontigConfig::default() };
    let mut ctx = IndexingContext::new();
    let ranges = vec![
        ctx.new_range(&SInt::Const(10), AxisType::Loop),
        ctx.new_range(&SInt::Const(10), AxisType::Loop),
    ];
    let [idx1, idx2, idx3, idx4] = [1, 2, 3, 4].map(|slot| {
        UOp::index()
            .buffer(create_test_buffer(100, DType::Float32, slot))
            .indices(ranges.clone())
            .call()
            .expect("Failed to create INDEX")
    });
    let lhs = idx1.try_add(&idx2).expect("Failed to create ADD");
    let rhs = idx3.try_add(&idx4).expect("Failed to create ADD");
    let product = lhs.try_mul(&rhs).expect("Failed to create MUL");
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let idx_buf = create_index_bufferize(product, ranges.clone(), ranges, opts);
    let result = graph_rewrite(&matcher, idx_buf.clone(), &mut config);
    assert!(Arc::ptr_eq(&result, &idx_buf), "Expected no rewrite with 4 buffers in nested computation");
}
/// A plain INDEX through BUFFERIZE (input 1000, output 9000) is inlined by the
/// simple path, leaving no BUFFERIZE behind.
#[test]
fn test_out_in_ratio_efficient_buffer() {
    let matcher = buffer_removal_with_pcontig();
    let mut config = PcontigConfig::default();
    let mut ctx = IndexingContext::new();
    let (_input, _out_size, ranges, compute) = create_ratio_test_graph(&mut ctx, 1000, 9000);
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let idx_buf = create_index_bufferize(compute, ranges.clone(), ranges, opts);
    let rewritten = graph_rewrite(&matcher, idx_buf.clone(), &mut config);
    let was_rewritten = !Arc::ptr_eq(&rewritten, &idx_buf);
    assert!(was_rewritten, "Simple path should always inline (no ratio guard, matching Tinygrad)");
    assert_eq!(count_bufferizes(&rewritten), 0, "BUFFERIZE should be removed after inlining");
}
/// Ratio test with `output_size / input_size` equal to the default threshold of 10.
///
/// The graph is a plain INDEX with no reduce, so it takes the simple removal path.
/// That path has no out/in ratio guard (see `test_out_in_ratio_efficient_buffer`),
/// so the BUFFERIZE must still be inlined.
#[test]
fn test_out_in_ratio_at_threshold() {
    let mut config = PcontigConfig::default();
    let matcher = buffer_removal_with_pcontig();
    let mut ctx = IndexingContext::new();
    let input_size = 1000;
    let output_size = 10000;
    let (_input_buffer, _output_size, ranges, compute) = create_ratio_test_graph(&mut ctx, input_size, output_size);
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let idx_buf = create_index_bufferize(compute, ranges.clone(), ranges, opts);
    let rewritten = graph_rewrite(&matcher, idx_buf.clone(), &mut config);
    assert!(!Arc::ptr_eq(&rewritten, &idx_buf), "Simple path should inline even at the ratio threshold");
    assert_eq!(count_bufferizes(&rewritten), 0, "BUFFERIZE should be removed after inlining");
}
/// Ratio test with output much larger than input (`output_size / input_size = 50`).
///
/// The simple path has no ratio guard. The lower-ratio case in
/// `test_out_in_ratio_efficient_buffer` is already inlined, so this one must be
/// inlined as well.
#[test]
fn test_out_in_ratio_wasteful_buffer() {
    let mut config = PcontigConfig::default();
    let matcher = buffer_removal_with_pcontig();
    let mut ctx = IndexingContext::new();
    let input_size = 1000;
    let output_size = 50000;
    let (_input_buffer, _output_size, ranges, compute) = create_ratio_test_graph(&mut ctx, input_size, output_size);
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let idx_buf = create_index_bufferize(compute, ranges.clone(), ranges, opts);
    let rewritten = graph_rewrite(&matcher, idx_buf.clone(), &mut config);
    assert!(!Arc::ptr_eq(&rewritten, &idx_buf), "Wasteful buffer should be inlined by the simple path");
    assert_eq!(count_bufferizes(&rewritten), 0, "BUFFERIZE should be removed after inlining");
}
/// Reads a very large input (param size 100,000,000) through a small 16x32
/// iteration space. This only checks that the rewrite completes.
#[test]
fn test_out_in_ratio_flash_attention_simulation() {
    let matcher = buffer_removal_with_pcontig();
    let mut config = PcontigConfig::default();
    let mut ctx = IndexingContext::new();
    let huge_input = create_buffer_with_size(100_000_000, DType::Float32, 1);
    let ranges = vec![
        ctx.new_range(&SInt::Const(16), AxisType::Loop),
        ctx.new_range(&SInt::Const(32), AxisType::Loop),
    ];
    let load = UOp::index().buffer(huge_input).indices(ranges.clone()).call().expect("Failed to create INDEX");
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let idx_buf = create_index_bufferize(load, ranges.clone(), ranges, opts);
    let _rewritten = graph_rewrite(&matcher, idx_buf, &mut config);
}
/// Uses a range whose extent is a symbolic `Param` instead of a constant. This only
/// checks that the rewrite completes.
#[test]
fn test_out_in_ratio_symbolic_sizes() {
    let matcher = buffer_removal_with_pcontig();
    let mut config = PcontigConfig::default();
    let symbolic_extent = UOp::param(1, 1, DType::Index, None);
    let data = UOp::param(2, 1, DType::Float32, None);
    let ranges = vec![UOp::range(symbolic_extent, 0)];
    let load = UOp::index().buffer(data).indices(ranges.clone()).call().expect("Failed to create INDEX");
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let idx_buf = create_index_bufferize(load, ranges.clone(), ranges, opts);
    let _rewritten = graph_rewrite(&matcher, idx_buf, &mut config);
}
/// `INDEX(BUFFERIZE(const))` has no input buffers. A bottom-up rewrite must reduce it
/// to the constant.
#[test]
fn test_out_in_ratio_no_inputs() {
    use crate::rewrite::graph_rewrite_bottom_up;
    let matcher = buffer_removal_with_pcontig();
    let mut config = PcontigConfig::default();
    let mut ctx = IndexingContext::new();
    let ranges = vec![ctx.new_range(&SInt::Const(10), AxisType::Loop)];
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let idx_buf = create_index_bufferize(UOp::native_const(1.0f32), ranges.clone(), ranges, opts);
    let result = graph_rewrite_bottom_up(&matcher, idx_buf, &mut config);
    assert!(matches!(result.op(), Op::Const(_)), "Expected constant to be inlined");
}
/// `INDEX(BUFFERIZE(INDEX(buf) + 1))` with no reduce anywhere in the graph.
///
/// No reduce touches the buffer and only one buffer is accessed, so the BUFFERIZE
/// should be fully removed, as the test name states.
#[test]
fn test_buffer_not_in_reduce_full_removal() {
    let mut config = PcontigConfig::default();
    let matcher = buffer_removal_with_pcontig();
    let mut ctx = IndexingContext::new();
    let buffer = create_test_buffer(100, DType::Float32, 1);
    let range1 = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let range2 = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let ranges = vec![range1, range2];
    let indexed = UOp::index().buffer(buffer).indices(ranges.clone()).call().expect("Failed to create INDEX");
    let one = UOp::native_const(1.0f32);
    let compute = indexed.try_add(&one).expect("Failed to create ADD");
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let idx_buf = create_index_bufferize(compute, ranges.clone(), ranges, opts);
    let rewritten = graph_rewrite(&matcher, idx_buf, &mut config);
    assert_eq!(count_bufferizes(&rewritten), 0, "BUFFERIZE should be fully removed when no reduce reads the buffer");
}
/// BUFFERIZE of a reduce whose source reads a buffer. This only checks that the
/// rewrite completes.
#[test]
fn test_buffer_in_reduce_partial_contiguous() {
    let matcher = buffer_removal_with_pcontig();
    let mut config = PcontigConfig::default();
    let mut ctx = IndexingContext::new();
    let (ranges, reduce) = create_reduce_graph(&mut ctx, true);
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let _rewritten = graph_rewrite(&matcher, UOp::bufferize(reduce, ranges, opts), &mut config);
}
/// Adds a buffer load to a reduce over a constant, then rewrites
/// `INDEX(BUFFERIZE(...))`. This only checks that the rewrite completes.
#[test]
fn test_reduce_without_buffer_access_full_removal() {
    let matcher = buffer_removal_with_pcontig();
    let mut config = PcontigConfig::default();
    let mut ctx = IndexingContext::new();
    let (ranges, const_reduce) = create_reduce_graph(&mut ctx, false);
    let load = UOp::index()
        .buffer(create_test_buffer(100, DType::Float32, 2))
        .indices(ranges.clone())
        .call()
        .expect("Failed to create INDEX");
    let sum = load.try_add(&const_reduce).expect("Failed to create ADD");
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let idx_buf = create_index_bufferize(sum, ranges.clone(), ranges, opts);
    let _rewritten = graph_rewrite(&matcher, idx_buf, &mut config);
}
/// Two chained ADD-reduces (over ranges of 10, then 4) of a buffer load, wrapped in
/// BUFFERIZE. This only checks that the rewrite completes.
#[test]
fn test_multiple_reduces_with_buffer() {
    let matcher = buffer_removal_with_pcontig();
    let mut config = PcontigConfig::default();
    let mut ctx = IndexingContext::new();
    let buffer = create_test_buffer(1600, DType::Float32, 1);
    let loop_range = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let first_axis = ctx.new_range(&SInt::Const(10), AxisType::Reduce);
    let second_axis = ctx.new_range(&SInt::Const(4), AxisType::Reduce);
    let ranges = vec![loop_range, first_axis.clone(), second_axis.clone()];
    let load = UOp::index().buffer(buffer).indices(ranges.clone()).call().expect("Failed to create INDEX");

    let sum_over = |src: Arc<UOp>, axis: Arc<UOp>| {
        UOp::new(
            Op::Reduce { src, ranges: vec![axis].into(), reduce_op: morok_ir::ReduceOp::Add },
            DType::Float32,
        )
    };
    // Reduce the first axis, then reduce that result over the second.
    let reduced = sum_over(sum_over(load, first_axis), second_axis);

    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let _rewritten = graph_rewrite(&matcher, UOp::bufferize(reduced, ranges, opts), &mut config);
}
/// An outer reduce (width 10) wrapping an inner reduce (width 5) of a buffer load,
/// bufferized over the outer range. This only checks that the rewrite completes.
#[test]
fn test_nested_reduce_with_buffer() {
    let matcher = buffer_removal_with_pcontig();
    let mut config = PcontigConfig::default();
    let mut ctx = IndexingContext::new();
    let buffer = create_test_buffer(200, DType::Float32, 1);
    let outer_axis = ctx.new_range(&SInt::Const(10), AxisType::Reduce);
    let ranges = vec![outer_axis.clone()];
    let load = UOp::index().buffer(buffer).indices(ranges.clone()).call().expect("Failed to create INDEX");
    let inner_axis = ctx.new_range(&SInt::Const(5), AxisType::Reduce);

    let sum_over = |src: Arc<UOp>, axis: Arc<UOp>| {
        UOp::new(
            Op::Reduce { src, ranges: vec![axis].into(), reduce_op: morok_ir::ReduceOp::Add },
            DType::Float32,
        )
    };
    let nested = sum_over(sum_over(load, inner_axis), outer_axis);

    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let _rewritten = graph_rewrite(&matcher, UOp::bufferize(nested, ranges, opts), &mut config);
}
/// Counts the BUFFERIZE nodes reachable from `uop`, including `uop` itself.
///
/// The traversal is an iterative DFS over `op().sources()`. Each node is visited at
/// most once, with visits tracked through `morok_ir::UOpKey`, so shared subgraphs
/// are not double-counted.
// Clippy flags `UOpKey` as a mutable key type. NOTE(review): presumably its Hash/Eq
// do not observe interior mutability; confirm before removing this allow.
#[allow(clippy::mutable_key_type)]
fn count_bufferizes(uop: &Arc<UOp>) -> usize {
    let mut seen = std::collections::HashSet::new();
    let mut pending = vec![uop.clone()];
    let mut total = 0;
    while let Some(node) = pending.pop() {
        let first_visit = seen.insert(morok_ir::UOpKey(node.clone()));
        if first_visit {
            if matches!(node.op(), Op::Bufferize { .. }) {
                total += 1;
            }
            for child in node.op().sources() {
                pending.push(child.clone());
            }
        }
    }
    total
}
/// Pattern 1: BUFFERIZE(const) is replaced by the const, and no BUFFERIZE remains.
#[test]
fn test_pattern1_cheap_inline() {
    let matcher = buffer_removal_with_pcontig();
    let mut config = PcontigConfig::default();
    let mut ctx = IndexingContext::new();
    let range = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let bufferized = UOp::bufferize(UOp::native_const(42.0f32), vec![range], opts);
    let result = graph_rewrite(&matcher, bufferized, &mut config);
    assert!(matches!(result.op(), Op::Const(_)), "Pattern 1 should inline const");
    assert_eq!(count_bufferizes(&result), 0, "No BUFFERIZE should remain after Pattern 1");
}
/// With a permissive config (level 2, buffer threshold 10, ratio threshold 1.0),
/// rewriting a plain INDEX through BUFFERIZE must reduce the BUFFERIZE count.
#[test]
fn test_pattern4_full_removal_with_permissive_config() {
    let mut config = PcontigConfig { level: 2, max_buffers_threshold: 10, out_in_ratio_threshold: 1.0 };
    let matcher = buffer_removal_with_pcontig();
    let mut ctx = IndexingContext::new();
    let buffer = create_test_buffer(40, DType::Float32, 1);
    let range = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let load = UOp::index().buffer(buffer).indices(vec![range.clone()]).call().expect("Failed to create INDEX");
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let idx_buf = create_index_bufferize(load, vec![range.clone()], vec![range], opts);
    let bufferizes_before = count_bufferizes(&idx_buf);
    let bufferizes_after = count_bufferizes(&graph_rewrite(&matcher, idx_buf, &mut config));
    assert!(
        bufferizes_after < bufferizes_before,
        "Expected BUFFERIZE removal with permissive config, before={}, after={}",
        bufferizes_before,
        bufferizes_after
    );
}
/// With the default config, a plain INDEX through BUFFERIZE is rewritten; the
/// result is not the input node.
#[test]
fn test_pattern4_keeps_efficient_buffer() {
    let matcher = buffer_removal_with_pcontig();
    let mut config = PcontigConfig::default();
    let mut ctx = IndexingContext::new();
    let input_buffer = create_test_buffer(40, DType::Float32, 1);
    let range = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let load = UOp::index()
        .buffer(input_buffer)
        .indices(vec![range.clone()])
        .call()
        .expect("Failed to create INDEX");
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let idx_buf = create_index_bufferize(load, vec![range.clone()], vec![range], opts);
    let result = graph_rewrite(&matcher, idx_buf.clone(), &mut config);
    assert!(!Arc::ptr_eq(&result, &idx_buf), "Simple path should always inline (matching Tinygrad)");
}
/// Inlining BUFFERIZE(const) must keep the constant's dtype.
#[test]
fn test_pattern1_preserves_dtype() {
    let matcher = buffer_removal_with_pcontig();
    let mut config = PcontigConfig::default();
    let mut ctx = IndexingContext::new();
    let range = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let constant = UOp::native_const(42.0f32);
    let expected_dtype = constant.dtype();
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let result = graph_rewrite(&matcher, UOp::bufferize(constant, vec![range], opts), &mut config);
    assert_eq!(result.dtype(), expected_dtype, "Dtype should be preserved after Pattern 1");
}
/// Four summed buffer loads rewritten with level 0 must come back as the same node.
#[test]
fn test_full_removal_blocked_by_heuristics() {
    let matcher = buffer_removal_with_pcontig();
    let mut config = PcontigConfig { level: 0, ..PcontigConfig::default() };
    let mut ctx = IndexingContext::new();
    let (_buffers, ranges, sum) = create_multi_buffer_graph(&mut ctx, 4);
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let idx_buf = create_index_bufferize(sum, ranges.clone(), ranges, opts);
    let result = graph_rewrite(&matcher, idx_buf.clone(), &mut config);
    assert!(Arc::ptr_eq(&result, &idx_buf), "Expected no rewrite when heuristics prevent optimization");
}
/// INDEX of a BUFFERIZE over a single-axis sum (reduce width 10, loop width 100).
/// This only checks that the rewrite completes.
#[test]
fn test_partial_contiguous_single_reduce() {
    let matcher = buffer_removal_with_pcontig();
    let mut config = PcontigConfig::default();
    let mut ctx = IndexingContext::new();
    let buffer = create_test_buffer(4000, DType::Float32, 1);
    let loop_range = ctx.new_range(&SInt::Const(100), AxisType::Loop);
    let reduce_range = ctx.new_range(&SInt::Const(10), AxisType::Reduce);
    let load = UOp::index()
        .buffer(buffer)
        .indices(vec![loop_range.clone(), reduce_range.clone()])
        .call()
        .expect("Failed to create INDEX");
    let sum = UOp::new(
        Op::Reduce { src: load, ranges: vec![reduce_range].into(), reduce_op: morok_ir::ReduceOp::Add },
        DType::Float32,
    );
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let idx_buf = create_index_bufferize(sum, vec![loop_range.clone()], vec![loop_range], opts);
    let _rewritten = graph_rewrite(&matcher, idx_buf, &mut config);
}
/// `INDEX(buf) * 2` over one loop axis and one local axis, wrapped in
/// `INDEX(BUFFERIZE(...))`. This only checks that the rewrite completes.
#[test]
fn test_partial_contiguous_local_axis() {
    let matcher = buffer_removal_with_pcontig();
    let mut config = PcontigConfig::default();
    let mut ctx = IndexingContext::new();
    let buffer = create_test_buffer(400, DType::Float32, 1);
    let axes = vec![
        ctx.new_range(&SInt::Const(10), AxisType::Loop),
        ctx.new_range(&SInt::Const(10), AxisType::Local),
    ];
    let load = UOp::index().buffer(buffer).indices(axes.clone()).call().expect("Failed to create INDEX");
    let doubled = load.try_mul(&UOp::native_const(2.0f32)).expect("Failed to create MUL");
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let idx_buf = create_index_bufferize(doubled, axes.clone(), axes, opts);
    let _rewritten = graph_rewrite(&matcher, idx_buf, &mut config);
}
/// INDEX of a BUFFERIZE over a MAX-reduce (reduce width 20, loop width 100). This
/// only checks that the rewrite completes.
#[test]
fn test_partial_contiguous_mixed_axes() {
    let matcher = buffer_removal_with_pcontig();
    let mut config = PcontigConfig::default();
    let mut ctx = IndexingContext::new();
    let buffer = create_test_buffer(8000, DType::Float32, 1);
    let loop_range = ctx.new_range(&SInt::Const(100), AxisType::Loop);
    let reduce_range = ctx.new_range(&SInt::Const(20), AxisType::Reduce);
    let load = UOp::index()
        .buffer(buffer)
        .indices(vec![loop_range.clone(), reduce_range.clone()])
        .call()
        .expect("Failed to create INDEX");
    let max = UOp::new(
        Op::Reduce { src: load, ranges: vec![reduce_range].into(), reduce_op: morok_ir::ReduceOp::Max },
        DType::Float32,
    );
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let idx_buf = create_index_bufferize(max, vec![loop_range.clone()], vec![loop_range], opts);
    let _rewritten = graph_rewrite(&matcher, idx_buf, &mut config);
}
/// Runs the single-reduce partial-contiguous scenario once for each of ADD, MAX,
/// and MUL. This only checks that each rewrite completes.
#[test]
fn test_partial_contiguous_different_reduce_ops() {
    use morok_ir::ReduceOp;
    for reduce_op in [ReduceOp::Add, ReduceOp::Max, ReduceOp::Mul] {
        // Use a fresh matcher, config, and context per op so the cases stay independent.
        let matcher = buffer_removal_with_pcontig();
        let mut config = PcontigConfig::default();
        let mut ctx = IndexingContext::new();
        let buffer = create_test_buffer(400, DType::Float32, 1);
        let loop_range = ctx.new_range(&SInt::Const(10), AxisType::Loop);
        let reduce_range = ctx.new_range(&SInt::Const(10), AxisType::Reduce);
        let load = UOp::index()
            .buffer(buffer)
            .indices(vec![loop_range.clone(), reduce_range.clone()])
            .call()
            .expect("Failed to create INDEX");
        let reduced = load.reduce(smallvec::smallvec![reduce_range], reduce_op);
        let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
        let idx_buf = create_index_bufferize(reduced, vec![loop_range.clone()], vec![loop_range], opts);
        let _rewritten = graph_rewrite(&matcher, idx_buf, &mut config);
    }
}
/// INDEX of a BUFFERIZE over an ADD-reduce across two reduce axes (widths 20 and
/// 10). This only checks that the rewrite completes.
#[test]
fn test_partial_contiguous_multi_dimensional_reduce() {
    let matcher = buffer_removal_with_pcontig();
    let mut config = PcontigConfig::default();
    let mut ctx = IndexingContext::new();
    let buffer = create_test_buffer(8000, DType::Float32, 1);
    let loop_range = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let reduce_a = ctx.new_range(&SInt::Const(20), AxisType::Reduce);
    let reduce_b = ctx.new_range(&SInt::Const(10), AxisType::Reduce);
    let load = UOp::index()
        .buffer(buffer)
        .indices(vec![loop_range.clone(), reduce_a.clone(), reduce_b.clone()])
        .call()
        .expect("Failed to create INDEX");
    let sum = UOp::new(
        Op::Reduce { src: load, ranges: vec![reduce_a, reduce_b].into(), reduce_op: morok_ir::ReduceOp::Add },
        DType::Float32,
    );
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let idx_buf = create_index_bufferize(sum, vec![loop_range.clone()], vec![loop_range], opts);
    let _rewritten = graph_rewrite(&matcher, idx_buf, &mut config);
}
/// A reduce over the sum of four distinct buffers exceeds the default buffer
/// threshold, so `INDEX(BUFFERIZE(reduce))` must come back unchanged.
#[test]
fn test_partial_contiguous_blocked_by_heuristics() {
    let matcher = buffer_removal_with_pcontig();
    let mut config = PcontigConfig::default();
    let mut ctx = IndexingContext::new();
    let loop_range = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let reduce_range = ctx.new_range(&SInt::Const(10), AxisType::Reduce);
    let [idx1, idx2, idx3, idx4] = [1, 2, 3, 4].map(|slot| {
        UOp::index()
            .buffer(create_test_buffer(40, DType::Float32, slot))
            .indices(vec![loop_range.clone()])
            .call()
            .expect("Failed to create INDEX")
    });
    let lhs = idx1.try_add(&idx2).expect("Failed to create ADD");
    let rhs = idx3.try_add(&idx4).expect("Failed to create ADD");
    let total = lhs.try_add(&rhs).expect("Failed to create ADD");
    let sum = UOp::new(
        Op::Reduce { src: total, ranges: vec![reduce_range].into(), reduce_op: morok_ir::ReduceOp::Add },
        DType::Float32,
    );
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let idx_buf = create_index_bufferize(sum, vec![loop_range.clone()], vec![loop_range], opts);
    let result = graph_rewrite(&matcher, idx_buf.clone(), &mut config);
    assert!(Arc::ptr_eq(&result, &idx_buf), "Expected no rewrite when accessed_buffers > threshold");
}
/// INDEX(BUFFERIZE(INDEX(buf))) with no arithmetic between the loads.
///
/// This is the same graph as `test_pattern4_keeps_efficient_buffer`: a 40-element
/// f32 buffer in slot 1, one 10-wide loop range, and the default config. That test
/// asserts the simple path inlines it, so this test asserts the same outcome instead
/// of only checking that the rewrite completes.
#[test]
fn test_edge_case_empty_computation() {
    let mut config = PcontigConfig::default();
    let matcher = buffer_removal_with_pcontig();
    let mut ctx = IndexingContext::new();
    let buffer = create_test_buffer(40, DType::Float32, 1);
    let range = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let indexed = UOp::index().buffer(buffer).indices(vec![range.clone()]).call().expect("Failed to create INDEX");
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let idx_buf = create_index_bufferize(indexed, vec![range.clone()], vec![range], opts);
    let rewritten = graph_rewrite(&matcher, idx_buf.clone(), &mut config);
    assert!(!Arc::ptr_eq(&rewritten, &idx_buf), "Simple path should inline a bare INDEX through BUFFERIZE");
    assert_eq!(count_bufferizes(&rewritten), 0, "BUFFERIZE should be removed after inlining");
}
/// A bare BUFFERIZE of `1.0 + 2.0` (an ADD node, not a CONST) must be left untouched.
#[test]
fn test_edge_case_all_const_operations() {
    let matcher = buffer_removal_with_pcontig();
    let mut config = PcontigConfig::default();
    let mut ctx = IndexingContext::new();
    let range = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let lhs = UOp::native_const(1.0f32);
    let rhs = UOp::native_const(2.0f32);
    let sum = lhs.try_add(&rhs).expect("Failed to create ADD");
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let bufferized = UOp::bufferize(sum, vec![range], opts);
    let result = graph_rewrite(&matcher, bufferized.clone(), &mut config);
    assert!(Arc::ptr_eq(&result, &bufferized), "Bare BUFFERIZE(non-const compute) should not be removed");
}
/// Three stacked BUFFERIZEs around a constant, rewritten twice. This only checks
/// that both passes complete.
#[test]
fn test_edge_case_deeply_nested_bufferize() {
    let matcher = buffer_removal_with_pcontig();
    let mut config = PcontigConfig::default();
    let mut ctx = IndexingContext::new();
    let range = ctx.new_range(&SInt::Const(5), AxisType::Loop);
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let nested = (0..3).fold(UOp::native_const(42.0f32), |inner, _| {
        UOp::bufferize(inner, vec![range.clone()], opts.clone())
    });
    let _rewritten = (0..2).fold(nested, |uop, _| graph_rewrite(&matcher, uop, &mut config));
}
/// Degenerate loop of extent 0. The source is indexed by the zero-length range,
/// doubled, and bufferized over that same range. Only checks that the rewrite
/// completes without panicking.
#[test]
fn test_edge_case_zero_sized_range() {
    let mut config: PcontigConfig = Default::default();
    let matcher = buffer_removal_with_pcontig();
    let mut ctx = IndexingContext::new();

    let src_buf = create_test_buffer(40, DType::Float32, 1);
    let empty_rng = ctx.new_range(&SInt::Const(0), AxisType::Loop);
    let doubled = UOp::index()
        .buffer(src_buf)
        .indices(vec![empty_rng.clone()])
        .call()
        .expect("Failed to create INDEX")
        .try_mul(&UOp::native_const(2.0f32))
        .expect("Failed to create MUL");

    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let root = create_index_bufferize(doubled, vec![empty_rng.clone()], vec![empty_rng], opts);
    let _ = graph_rewrite(&matcher, root, &mut config);
}
/// Four-input graph rewritten twice, both times at `level: 0`.
///
/// - With the default `max_buffers_threshold`, the root must be left untouched.
/// - With the threshold raised to 10, the test only checks that the rewrite does not panic.
#[test]
fn test_config_custom_max_buffers_threshold() {
    let matcher = buffer_removal_with_pcontig();
    let mut ctx = IndexingContext::new();
    let (_buffers, ranges, compute) = create_multi_buffer_graph(&mut ctx, 4);
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let root = create_index_bufferize(compute, ranges.clone(), ranges, opts);

    let mut baseline = PcontigConfig { level: 0, ..Default::default() };
    let untouched = graph_rewrite(&matcher, root.clone(), &mut baseline);
    assert!(Arc::ptr_eq(&untouched, &root), "Default config should block 4 buffers");

    let mut relaxed = PcontigConfig { level: 0, max_buffers_threshold: 10, ..Default::default() };
    let _ = graph_rewrite(&matcher, root, &mut relaxed);
}
/// Sum-reduction over a 5-wide reduce axis, bufferized over the 10-wide outer
/// axis, then rewritten under two `out_in_ratio_threshold` settings.
///
/// - At 100.0 the buffer must be kept.
/// - At 0.1 the test only checks that the rewrite does not panic.
#[test]
fn test_config_custom_ratio_threshold() {
    use morok_ir::types::ReduceOp;
    let matcher = buffer_removal_with_pcontig();
    let mut ctx = IndexingContext::new();
    let buffer = create_test_buffer(40, DType::Float32, 1);
    let outer_range = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let reduce_range = ctx.new_range(&SInt::Const(5), AxisType::Reduce);
    let indexed = UOp::index()
        .buffer(buffer)
        .indices(vec![outer_range.clone(), reduce_range.clone()])
        .call()
        .expect("Failed to create INDEX");
    let compute = indexed.reduce(smallvec::smallvec![reduce_range], ReduceOp::Add);
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let idx_buf = create_index_bufferize(compute, vec![outer_range.clone()], vec![outer_range], opts);

    let mut strict_config = PcontigConfig { out_in_ratio_threshold: 100.0, ..Default::default() };
    let strict_rewritten = graph_rewrite(&matcher, idx_buf.clone(), &mut strict_config);
    assert!(Arc::ptr_eq(&strict_rewritten, &idx_buf), "Strict config should keep buffer (low ratio)");

    let mut permissive_config = PcontigConfig { out_in_ratio_threshold: 0.1, ..Default::default() };
    // This is the last use of `idx_buf`, so it is handed over instead of cloned.
    let _ = graph_rewrite(&matcher, idx_buf, &mut permissive_config);
}
/// Rewrites the same single-input `x + 1.0` graph under two configurations.
///
/// - `level: 0` must leave the root untouched.
/// - `level: 2` with `out_in_ratio_threshold: 1.0` must produce a different root.
#[test]
fn test_config_level_0_vs_2() {
    let matcher = buffer_removal_with_pcontig();
    let mut ctx = IndexingContext::new();
    let src_buf = create_test_buffer(40, DType::Float32, 1);
    let rng = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let plus_one = UOp::index()
        .buffer(src_buf)
        .indices(vec![rng.clone()])
        .call()
        .expect("Failed to create INDEX")
        .try_add(&UOp::native_const(1.0f32))
        .expect("Failed to create ADD");
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let root = create_index_bufferize(plus_one, vec![rng.clone()], vec![rng], opts);

    let mut level0 = PcontigConfig { level: 0, ..Default::default() };
    let after_level0 = graph_rewrite(&matcher, root.clone(), &mut level0);
    assert!(Arc::ptr_eq(&after_level0, &root), "Level 0 should disable Pattern 4 optimizations");

    let mut level2 = PcontigConfig { level: 2, out_in_ratio_threshold: 1.0, ..Default::default() };
    let after_level2 = graph_rewrite(&matcher, root.clone(), &mut level2);
    assert!(!Arc::ptr_eq(&after_level2, &root), "Level 2 should enable Pattern 4 optimizations");
}
/// Three-input graph rewritten at `level: 0`, first with `max_buffers_threshold`
/// set to 2 (the root must be left untouched) and then with it set to 5.
///
/// NOTE(review): despite the test name, only the first configuration's result
/// is asserted. The second rewrite is only checked for not panicking.
#[test]
fn test_config_different_configs_produce_different_results() {
    let matcher = buffer_removal_with_pcontig();
    let mut ctx = IndexingContext::new();
    let (_buffers, ranges, compute) = create_multi_buffer_graph(&mut ctx, 3);
    let root = create_index_bufferize(
        compute,
        ranges.clone(),
        ranges,
        BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true },
    );

    let mut tight = PcontigConfig { level: 0, max_buffers_threshold: 2, ..Default::default() };
    let tight_result = graph_rewrite(&matcher, root.clone(), &mut tight);
    assert!(Arc::ptr_eq(&tight_result, &root), "Config1 should block 3 buffers (threshold=2)");

    let mut loose = PcontigConfig { level: 0, max_buffers_threshold: 5, ..Default::default() };
    let _ = graph_rewrite(&matcher, root, &mut loose);
}
/// End-to-end `rangeify` over a 6-element buffer that is reshaped to 2x3 and
/// then transposed. Checks that the call succeeds and that the result keeps the
/// Float32 dtype.
#[test]
fn test_pipeline_integration_full_rangeify() {
    use crate::rangeify::rangeify;

    let transposed = UOp::new_buffer(morok_device::DeviceSpec::Cpu, 6, DType::Float32)
        .try_reshape(&smallvec::smallvec![morok_ir::SInt::Const(2), morok_ir::SInt::Const(3)])
        .unwrap()
        .try_permute(vec![1, 0])
        .unwrap();
    let (lowered, _ctx) = rangeify(transposed, None).expect("Rangeify should succeed");
    assert_eq!(lowered.dtype(), DType::Float32);
}
/// Two stacked BUFFERIZEs over the constant 1.0. The first rewrite must replace
/// the root ("Pattern 3 should fire"). That result is then rewritten a second
/// time, which is only checked for not panicking.
#[test]
fn test_pipeline_multiple_patterns_in_sequence() {
    let mut config = PcontigConfig::default();
    let matcher = buffer_removal_with_pcontig();
    let mut ctx = IndexingContext::new();
    let range = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let const_val = UOp::native_const(1.0f32);
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    // `const_val` has no later use, so it is moved rather than cloned.
    let inner = UOp::bufferize(const_val, vec![range.clone()], opts.clone());
    let outer = UOp::bufferize(inner, vec![range], opts);
    let rewritten1 = graph_rewrite(&matcher, outer.clone(), &mut config);
    assert!(!Arc::ptr_eq(&rewritten1, &outer), "Pattern 3 should fire");
    // `rewritten1` is likewise not needed after the second pass.
    let _ = graph_rewrite(&matcher, rewritten1, &mut config);
}
/// `(x * 2.0) + 1.0` over a single loop axis. Checks that the rewritten root
/// has the same dtype as the original INDEX(BUFFERIZE(...)) root.
#[test]
fn test_pipeline_preserves_graph_structure() {
    let mut config = PcontigConfig::default();
    let matcher = buffer_removal_with_pcontig();
    let mut ctx = IndexingContext::new();
    let src_buf = create_test_buffer(40, DType::Float32, 1);
    let rng = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let load = UOp::index().buffer(src_buf).indices(vec![rng.clone()]).call().expect("Failed to create INDEX");
    let scaled = load.try_mul(&UOp::native_const(2.0f32)).expect("Failed to create MUL");
    let shifted = scaled.try_add(&UOp::native_const(1.0f32)).expect("Failed to create ADD");
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let root = create_index_bufferize(shifted, vec![rng.clone()], vec![rng], opts);

    let expected_dtype = root.dtype();
    let after = graph_rewrite(&matcher, root, &mut config);
    assert_eq!(after.dtype(), expected_dtype, "Dtype should be preserved");
}
/// BUFFERIZE over `.neg()` of the constant 5.0, not wrapped in an INDEX. The
/// assertion message treats this source as non-const. The root must come back
/// unchanged.
#[test]
fn test_pipeline_cheap_inline_interaction() {
    let mut config = PcontigConfig::default();
    let matcher = buffer_removal_with_pcontig();
    let mut ctx = IndexingContext::new();
    let range = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let const_val = UOp::native_const(5.0f32);
    let neg = const_val.neg();
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    // `neg` has no later use, so it is moved into the BUFFERIZE rather than cloned.
    let bufferized = UOp::bufferize(neg, vec![range], opts);
    let rewritten = graph_rewrite(&matcher, bufferized.clone(), &mut config);
    assert!(Arc::ptr_eq(&rewritten, &bufferized), "Bare BUFFERIZE(non-const) should not be removed");
}
/// INDEX(BUFFERIZE(...)) over a range whose end is the symbolic variable
/// `batch` (declared with bounds 0 and 128). The source reads through a
/// separate concrete 10-wide range. Only checks that the rewrite completes
/// without panicking.
#[test]
fn test_symbolic_buffer_size_handling() {
    let matcher = buffer_removal_with_pcontig();
    let mut config = PcontigConfig::default();

    let batch = UOp::define_var("batch".to_string(), 0, 128);
    let batch_rng = UOp::range_axis(batch, AxisId::Renumbered(0), AxisType::Loop);
    let src_buf = create_test_buffer(40, DType::Float32, 1);
    // Concrete 10-wide loop range with axis id Renumbered(1).
    let fixed_rng = UOp::new(
        Op::Range {
            end: UOp::index_const(10),
            axis_id: AxisId::Renumbered(1),
            axis_type: AxisType::Loop,
            deps: smallvec::SmallVec::new(),
        },
        DType::Index,
    );

    let plus_one = UOp::index()
        .buffer(src_buf)
        .indices(vec![fixed_rng])
        .call()
        .expect("Failed to create INDEX")
        .try_add(&UOp::native_const(1.0f32))
        .expect("Failed to create ADD");
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let staged = UOp::bufferize(plus_one, vec![batch_rng.clone()], opts);
    let root = UOp::index().buffer(staged).indices(vec![batch_rng]).call().expect("Failed to create INDEX");
    let _ = graph_rewrite(&matcher, root, &mut config);
}
/// Two symbolic ranges, `n` and `m`, each declared with bounds 0 and 1024. The
/// source is indexed by the `n` range and doubled, then bufferized and
/// re-indexed over the `m` range. Only checks that the rewrite completes
/// without panicking.
///
/// NOTE(review): the source axis (`n`) differs from the buffer axis (`m`);
/// presumably this is deliberate, to mix unrelated symbolic extents. Confirm.
#[test]
fn test_all_symbolic_sizes() {
    let matcher = buffer_removal_with_pcontig();
    let mut config = PcontigConfig::default();

    let n_rng = UOp::range_axis(UOp::define_var("n".to_string(), 0, 1024), AxisId::Renumbered(0), AxisType::Loop);
    let m_rng = UOp::range_axis(UOp::define_var("m".to_string(), 0, 1024), AxisId::Renumbered(1), AxisType::Loop);
    let src_buf = create_test_buffer(4096, DType::Float32, 1);

    let doubled = UOp::index()
        .buffer(src_buf)
        .indices(vec![n_rng.clone()])
        .call()
        .expect("Failed to create INDEX")
        .try_mul(&UOp::native_const(2.0f32))
        .expect("Failed to create MUL");
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let staged = UOp::bufferize(doubled, vec![m_rng.clone()], opts);
    let root = UOp::index().buffer(staged).indices(vec![m_rng]).call().expect("Failed to create INDEX");
    let _ = graph_rewrite(&matcher, root, &mut config);
}
/// Two-axis INDEX(BUFFERIZE(...)) that mixes a concrete 10-wide range with a
/// symbolic `batch` range (declared with bounds 0 and 64). The source reads
/// only the concrete axis. Only checks that the rewrite completes without
/// panicking.
#[test]
fn test_mixed_concrete_symbolic_sizes() {
    let matcher = buffer_removal_with_pcontig();
    let mut config = PcontigConfig::default();

    let fixed_rng = UOp::new(
        Op::Range {
            end: UOp::index_const(10),
            axis_id: AxisId::Renumbered(0),
            axis_type: AxisType::Loop,
            deps: smallvec::SmallVec::new(),
        },
        DType::Index,
    );
    let batch_rng =
        UOp::range_axis(UOp::define_var("batch".to_string(), 0, 64), AxisId::Renumbered(1), AxisType::Loop);
    let src_buf = create_test_buffer(40, DType::Float32, 1);

    let plus_one = UOp::index()
        .buffer(src_buf)
        .indices(vec![fixed_rng.clone()])
        .call()
        .expect("Failed to create INDEX")
        .try_add(&UOp::native_const(1.0f32))
        .expect("Failed to create ADD");

    let axes = vec![fixed_rng, batch_rng];
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let staged = UOp::bufferize(plus_one, axes.clone(), opts);
    let root = UOp::index().buffer(staged).indices(axes).call().expect("Failed to create INDEX");
    let _ = graph_rewrite(&matcher, root, &mut config);
}
/// Diamond-shaped graph: a single INDEX feeds both `x * 2.0` and `x * 3.0`,
/// and the two branches join in an ADD that is then bufferized. Only checks
/// that the rewrite completes without panicking.
#[test]
fn test_complex_diamond_pattern() {
    let mut config = PcontigConfig::default();
    let matcher = buffer_removal_with_pcontig();
    let mut ctx = IndexingContext::new();
    let buffer = create_test_buffer(40, DType::Float32, 1);
    let range = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    // `buffer` has no later use, so it is moved into the INDEX rather than cloned.
    let indexed = UOp::index().buffer(buffer).indices(vec![range.clone()]).call().expect("Failed to create INDEX");
    let two = UOp::native_const(2.0f32);
    let three = UOp::native_const(3.0f32);
    // Both branches share `indexed`; this sharing is what forms the diamond.
    let mul1 = indexed.try_mul(&two).expect("Failed to create MUL");
    let mul2 = indexed.try_mul(&three).expect("Failed to create MUL");
    let add = mul1.try_add(&mul2).expect("Failed to create ADD");
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let idx_buf = create_index_bufferize(add, vec![range.clone()], vec![range], opts);
    let _ = graph_rewrite(&matcher, idx_buf, &mut config);
}
/// Five chained additions, `((((x + 1) + 2) + 3) + 4) + 5`, feeding an
/// INDEX(BUFFERIZE(...)). Only checks that the rewrite completes without
/// panicking.
#[test]
fn test_complex_deep_computation_chain() {
    let mut config = PcontigConfig::default();
    let matcher = buffer_removal_with_pcontig();
    let mut ctx = IndexingContext::new();
    let src_buf = create_test_buffer(40, DType::Float32, 1);
    let rng = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let load = UOp::index().buffer(src_buf).indices(vec![rng.clone()]).call().expect("Failed to create INDEX");

    let chain = (1..=5).fold(load, |acc, step| {
        let addend = UOp::native_const(step as f32);
        acc.try_add(&addend).expect("Failed to create ADD")
    });

    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let root = create_index_bufferize(chain, vec![rng.clone()], vec![rng], opts);
    let _ = graph_rewrite(&matcher, root, &mut config);
}
/// Five separate input buffers (slots 0 to 4), each indexed by the same range
/// and summed left to right. At `level: 0` with default thresholds, the root
/// must be left untouched.
#[test]
fn test_complex_multiple_independent_buffers() {
    let mut config = PcontigConfig { level: 0, ..PcontigConfig::default() };
    let matcher = buffer_removal_with_pcontig();
    let mut ctx = IndexingContext::new();
    let rng = ctx.new_range(&SInt::Const(10), AxisType::Loop);

    // One INDEX per input buffer.
    let loads: Vec<_> = (0..5)
        .map(|slot| {
            let buf = create_test_buffer(40, DType::Float32, slot);
            UOp::index().buffer(buf).indices(vec![rng.clone()]).call().expect("Failed to create INDEX")
        })
        .collect();
    let total = loads[1..]
        .iter()
        .fold(loads[0].clone(), |acc, load| acc.try_add(load).expect("Failed to create ADD"));

    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let root = create_index_bufferize(total, vec![rng.clone()], vec![rng], opts);
    let after = graph_rewrite(&matcher, root.clone(), &mut config);
    assert!(Arc::ptr_eq(&after, &root), "Should block optimization with 5 buffers");
}
/// Two reductions over the same source buffer are summed and then bufferized
/// over the shared 10-wide loop axis:
///
/// - an ADD over a 20-wide reduce axis;
/// - a MAX over a 40-wide reduce axis.
///
/// Only checks that the rewrite completes without panicking.
#[test]
fn test_complex_multiple_sequential_reduces() {
    let mut config = PcontigConfig::default();
    let matcher = buffer_removal_with_pcontig();
    let mut ctx = IndexingContext::new();
    let src_buf = create_test_buffer(8000, DType::Float32, 1);
    let outer = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let red_a = ctx.new_range(&SInt::Const(20), AxisType::Reduce);
    let red_b = ctx.new_range(&SInt::Const(40), AxisType::Reduce);

    // Builds REDUCE(INDEX(src_buf, [outer, axis]), [axis], op).
    let reduce_along = |axis: Arc<UOp>, op| {
        let load = UOp::index()
            .buffer(src_buf.clone())
            .indices(vec![outer.clone(), axis.clone()])
            .call()
            .expect("Failed to create INDEX");
        UOp::new(Op::Reduce { src: load, ranges: vec![axis].into(), reduce_op: op }, DType::Float32)
    };

    let sum_part = reduce_along(red_a, morok_ir::ReduceOp::Add);
    let max_part = reduce_along(red_b, morok_ir::ReduceOp::Max);
    let combined = sum_part.try_add(&max_part).expect("Failed to create ADD");

    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let staged = UOp::bufferize(combined, vec![outer.clone()], opts);
    let root = UOp::index().buffer(staged).indices(vec![outer]).call().expect("Failed to create INDEX");
    let _ = graph_rewrite(&matcher, root, &mut config);
}
/// Large extent: a 10000-wide loop over a buffer created with size 40000,
/// computing `x + 1.0`. Only checks that the rewrite completes without
/// panicking.
#[test]
fn test_boundary_very_large_buffer() {
    let mut config: PcontigConfig = Default::default();
    let matcher = buffer_removal_with_pcontig();
    let mut ctx = IndexingContext::new();
    let big_rng = ctx.new_range(&SInt::Const(10000), AxisType::Loop);
    let src_buf = create_test_buffer(40000, DType::Float32, 1);

    let plus_one = UOp::index()
        .buffer(src_buf)
        .indices(vec![big_rng.clone()])
        .call()
        .expect("Failed to create INDEX")
        .try_add(&UOp::native_const(1.0f32))
        .expect("Failed to create ADD");

    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let root = create_index_bufferize(plus_one, vec![big_rng.clone()], vec![big_rng], opts);
    let _ = graph_rewrite(&matcher, root, &mut config);
}
/// Extent-1 loop over a buffer created with size 4, computing `x * 2.0`. Only
/// checks that the rewrite completes without panicking.
#[test]
fn test_boundary_size_one_dimension() {
    let mut config: PcontigConfig = Default::default();
    let matcher = buffer_removal_with_pcontig();
    let mut ctx = IndexingContext::new();
    let unit_rng = ctx.new_range(&SInt::Const(1), AxisType::Loop);
    let src_buf = create_test_buffer(4, DType::Float32, 1);

    let doubled = UOp::index()
        .buffer(src_buf)
        .indices(vec![unit_rng.clone()])
        .call()
        .expect("Failed to create INDEX")
        .try_mul(&UOp::native_const(2.0f32))
        .expect("Failed to create MUL");

    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let root = create_index_bufferize(doubled, vec![unit_rng.clone()], vec![unit_rng], opts);
    let _ = graph_rewrite(&matcher, root, &mut config);
}
/// Source indexed over a 10-wide range, but bufferized and re-indexed over an
/// unrelated 102-wide range. Only checks that the rewrite completes without
/// panicking; no threshold outcome is asserted.
#[test]
fn test_boundary_exact_threshold_values() {
    let mut config = PcontigConfig::default();
    let matcher = buffer_removal_with_pcontig();
    let mut ctx = IndexingContext::new();
    let buffer = create_test_buffer(40, DType::Float32, 1);
    let input_range = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let output_range = ctx.new_range(&SInt::Const(102), AxisType::Loop);
    let indexed = UOp::index().buffer(buffer).indices(vec![input_range]).call().expect("Failed to create INDEX");
    let one = UOp::native_const(1.0f32);
    let compute = indexed.try_add(&one).expect("Failed to create ADD");
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let idx_buf = create_index_bufferize(compute, vec![output_range.clone()], vec![output_range], opts);
    // `idx_buf` is not needed afterwards, so ownership is passed instead of a clone.
    let _ = graph_rewrite(&matcher, idx_buf, &mut config);
}
/// BUFFERIZE directly over the constant 42.0. The root must be replaced
/// (assertion message: "Const should be inlined via Pattern 1").
#[test]
fn test_boundary_minimal_computation() {
    let matcher = buffer_removal_with_pcontig();
    let mut config: PcontigConfig = Default::default();
    let mut ctx = IndexingContext::new();
    let rng = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };

    let root = UOp::bufferize(UOp::native_const(42.0f32), vec![rng], opts);
    let after = graph_rewrite(&matcher, root.clone(), &mut config);
    assert!(!Arc::ptr_eq(&after, &root), "Const should be inlined via Pattern 1");
}