use std::collections::HashSet;
use std::sync::Arc;
use morok_device::DeviceSpec;
use morok_dtype::DType;
use morok_ir::{AddrSpace, AxisType, BufferizeOpts, Op, SInt, UOp, UOpKey};
use test_case::test_case;
use crate::rangeify::indexing::IndexingContext;
use crate::rangeify::patterns::{buffer_limit_patterns, extract_device_from_graph, is_elementwise};
use crate::rewrite::graph_rewrite;
/// Builds a `BUFFER` node of `size` and `dtype` on `device`, tagged with the
/// explicit unique id `id`.
fn create_test_buffer(size: usize, dtype: DType, id: usize, device: DeviceSpec) -> Arc<UOp> {
    UOp::new(
        Op::Buffer {
            unique: UOp::buffer_id(Some(id)),
            device: UOp::device(device),
            size,
        },
        dtype,
    )
}
/// Builds `num_buffers` Float32 buffers (ids `0..num_buffers`) on `device` and
/// sums them elementwise as `((b0[r] + b1[r]) + b2[r]) + ...`, where every
/// buffer is indexed by a single shared loop range of extent 10.
///
/// Returns the buffers in id order together with the root of the ADD chain.
///
/// # Panics
/// Panics if `num_buffers` is zero, or if building an INDEX or ADD node fails.
fn create_multi_buffer_computation(num_buffers: usize, device: DeviceSpec) -> (Vec<Arc<UOp>>, Arc<UOp>) {
    assert!(num_buffers > 0, "Must have at least one buffer");
    let mut ctx = IndexingContext::new();
    let ranges = vec![ctx.new_range(&SInt::Const(10), AxisType::Loop)];

    // Create buffer `id` and its indexed load, in that order.
    let make_input = |id: usize| {
        let buffer = create_test_buffer(40, DType::Float32, id, device.clone());
        let indexed = UOp::index().buffer(buffer.clone()).indices(ranges.clone()).call().unwrap();
        (buffer, indexed)
    };

    let (first_buffer, mut computation) = make_input(0);
    let mut buffers = vec![first_buffer];
    for id in 1..num_buffers {
        let (buffer, indexed) = make_input(id);
        computation = computation.try_add(&indexed).expect("Failed to create ADD");
        buffers.push(buffer);
    }
    (buffers, computation)
}
#[allow(clippy::mutable_key_type)]
fn count_bufferizes(uop: &Arc<UOp>) -> usize {
let mut count = 0;
let mut stack = vec![uop.clone()];
let mut visited = HashSet::new();
while let Some(current) = stack.pop() {
let key = UOpKey(current.clone());
if !visited.insert(key) {
continue;
}
if matches!(current.op(), Op::Bufferize { .. }) {
count += 1;
}
for child in current.op().sources() {
stack.push(child);
}
}
count
}
#[allow(clippy::mutable_key_type, dead_code)]
fn count_accessed_buffers(uop: &Arc<UOp>) -> usize {
let mut buffers = Vec::new();
let mut visited = HashSet::new();
fn visit(uop: &Arc<UOp>, buffers: &mut Vec<Arc<UOp>>, visited: &mut HashSet<UOpKey>) {
let key = UOpKey(Arc::clone(uop));
if !visited.insert(key) {
return;
}
match uop.op() {
Op::Bufferize { opts, .. } if opts.addrspace == AddrSpace::Global => {
buffers.push(Arc::clone(uop));
return; }
Op::Buffer { .. } | Op::MStack { .. } | Op::MSelect { .. } => {
buffers.push(Arc::clone(uop));
}
_ => {}
}
for child in uop.op().sources() {
visit(&child, buffers, visited);
}
}
visit(uop, &mut buffers, &mut visited);
let mut seen = HashSet::new();
buffers.retain(|b| seen.insert(UOpKey(Arc::clone(b))));
buffers.len()
}
#[cfg(feature = "metal")]
#[test]
fn test_metal_limit_at_threshold() {
    // 30 inputs plus the output exactly fill a 31-buffer limit, so the rewrite
    // must hand back the very same root node.
    let (_, computation) = create_multi_buffer_computation(30, DeviceSpec::Metal { device_id: 0 });
    let rewritten = graph_rewrite(&buffer_limit_patterns(31), computation.clone(), &mut ());
    assert!(
        Arc::ptr_eq(&rewritten, &computation),
        "Should not materialize when exactly at limit (30 buffers + 1 output = 31 total)"
    );
}
#[cfg(feature = "metal")]
#[test]
fn test_metal_limit_exceeded() {
    // 31 inputs plus the output is one over a 31-buffer limit, so at least one
    // new BUFFERIZE must appear.
    let (_, computation) = create_multi_buffer_computation(31, DeviceSpec::Metal { device_id: 0 });
    let bufferizes_before = count_bufferizes(&computation);
    let rewritten = graph_rewrite(&buffer_limit_patterns(31), computation.clone(), &mut ());
    let bufferizes_after = count_bufferizes(&rewritten);
    assert!(
        bufferizes_after > bufferizes_before,
        "Should materialize when buffer limit exceeded (31 buffers + 1 output = 32 total)"
    );
}
#[cfg(feature = "webgpu")]
#[test]
fn test_webgpu_limit_at_threshold() {
    // 7 inputs plus the output exactly fill an 8-buffer limit: expect no rewrite.
    let (_, computation) = create_multi_buffer_computation(7, DeviceSpec::WebGpu);
    let rewritten = graph_rewrite(&buffer_limit_patterns(8), computation.clone(), &mut ());
    assert!(
        Arc::ptr_eq(&rewritten, &computation),
        "Should not materialize when exactly at limit (7 buffers + 1 output = 8 total)"
    );
}
#[cfg(feature = "webgpu")]
#[test]
fn test_webgpu_limit_exceeded() {
    // 8 inputs plus the output is one over an 8-buffer limit: expect new BUFFERIZEs.
    let (_, computation) = create_multi_buffer_computation(8, DeviceSpec::WebGpu);
    let bufferizes_before = count_bufferizes(&computation);
    let rewritten = graph_rewrite(&buffer_limit_patterns(8), computation.clone(), &mut ());
    let bufferizes_after = count_bufferizes(&rewritten);
    assert!(
        bufferizes_after > bufferizes_before,
        "Should materialize when buffer limit exceeded (8 buffers + 1 output = 9 total)"
    );
}
/// A 100-input elementwise graph on CPU should be accepted by the rangeify
/// pipeline; the test name and messages state that CPU has no buffer limit.
///
/// Checks that the graph reports the CPU device and that
/// `crate::rangeify::rangeify` completes on it.
/// NOTE(review): no CPU-specific entry point for limit enforcement is visible
/// from this file, so "nothing was materialized" is not asserted on the
/// pipeline output. Confirm whether such an entry point exists and assert on it.
#[test]
fn test_cpu_no_limit() {
    let device = DeviceSpec::Cpu;
    let (_, computation) = create_multi_buffer_computation(100, device.clone());
    assert_eq!(count_bufferizes(&computation), 0, "Test graph should start without BUFFERIZE nodes");
    assert_eq!(
        extract_device_from_graph(&computation),
        Some(device),
        "100-buffer graph should report the CPU device"
    );
    let result = crate::rangeify::rangeify(computation, None);
    assert!(result.is_ok(), "CPU should have no buffer limit");
}
/// A 100-input elementwise graph on CUDA should be accepted by the rangeify
/// pipeline; the test name and messages state that CUDA has no buffer limit.
///
/// Checks that the graph reports the CUDA device and that
/// `crate::rangeify::rangeify` completes on it.
/// NOTE(review): no device-specific entry point for limit enforcement is
/// visible from this file, so "nothing was materialized" is not asserted on
/// the pipeline output. Confirm whether such an entry point exists and assert on it.
#[cfg(feature = "cuda")]
#[test]
fn test_cuda_no_limit() {
    let device = DeviceSpec::Cuda { device_id: 0 };
    let (_, computation) = create_multi_buffer_computation(100, device.clone());
    assert_eq!(count_bufferizes(&computation), 0, "Test graph should start without BUFFERIZE nodes");
    assert_eq!(
        extract_device_from_graph(&computation),
        Some(device),
        "100-buffer graph should report the CUDA device"
    );
    let result = crate::rangeify::rangeify(computation, None);
    assert!(result.is_ok(), "CUDA should have no buffer limit");
}
/// `1.0 + 2.0` is a binary ALU op and must be classified as elementwise.
#[test]
fn test_binary_op_is_elementwise() {
    let sum = UOp::native_const(1.0f32)
        .try_add(&UOp::native_const(2.0f32))
        .unwrap();
    assert!(is_elementwise(&sum), "Binary ADD should be elementwise");
}
/// `where(true, 1.0, 2.0)` is a ternary ALU op and must be classified as elementwise.
#[test]
fn test_ternary_op_is_elementwise() {
    let select = UOp::try_where(
        UOp::native_const(true),
        UOp::native_const(1.0f32),
        UOp::native_const(2.0f32),
    )
    .unwrap();
    assert!(is_elementwise(&select), "Ternary WHERE should be elementwise");
}
/// Leaf nodes (CONST and BUFFER) are not elementwise operations.
#[test]
fn test_non_elementwise_operations() {
    let constant = UOp::native_const(1.0f32);
    assert!(!is_elementwise(&constant), "CONST should not be elementwise");

    let buffer = create_test_buffer(100, DType::Float32, 1, DeviceSpec::Cpu);
    assert!(!is_elementwise(&buffer), "BUFFER should not be elementwise");
}
/// Over the Metal limit, the rewrite must introduce at least one BUFFERIZE
/// into the ADD chain.
#[cfg(feature = "metal")]
#[test]
fn test_materialize_only_elementwise() {
    let (_, computation) = create_multi_buffer_computation(31, DeviceSpec::Metal { device_id: 0 });
    let rewritten = graph_rewrite(&buffer_limit_patterns(31), computation, &mut ());
    assert!(
        count_bufferizes(&rewritten) > 0,
        "Should have materialized elementwise operations"
    );
}
/// The output buffer counts against the limit: with a limit of 31, 30 inputs
/// fit and 31 inputs do not.
#[test_case(30, false ; "at_limit_should_not_trigger")]
#[test_case(31, true ; "over_limit_should_trigger")]
fn test_output_buffer_accounting(num_buffers: usize, should_materialize: bool) {
    let (_, computation) = create_multi_buffer_computation(num_buffers, DeviceSpec::Cpu);
    let before = count_bufferizes(&computation);
    let rewritten = graph_rewrite(&buffer_limit_patterns(31), computation.clone(), &mut ());
    let after = count_bufferizes(&rewritten);
    if !should_materialize {
        assert_eq!(after, before, "Should not materialize when num_buffers={} (<= 30)", num_buffers);
    } else {
        assert!(after > before, "Should materialize when num_buffers={} (> 30)", num_buffers);
    }
}
/// A lone constant carries no device information.
#[test]
fn test_extract_device_no_device() {
    let found = extract_device_from_graph(&UOp::native_const(1.0f32));
    assert_eq!(found, None, "Should return None when no device");
}
/// A DEVICE node reports the device it wraps.
#[test]
fn test_extract_device_from_device_op() {
    let found = extract_device_from_graph(&UOp::device(DeviceSpec::Cpu));
    assert_eq!(found, Some(DeviceSpec::Cpu), "Should extract CPU device");
}
/// A BUFFER node reports the device of its DEVICE source.
#[test]
fn test_extract_device_from_buffer() {
    let expected = DeviceSpec::Cpu;
    let buffer = create_test_buffer(100, DType::Float32, 1, expected.clone());
    let found = extract_device_from_graph(&buffer);
    assert_eq!(found, Some(expected), "Should extract device from BUFFER");
}
/// A value already wrapped in a global `BUFFERIZE` must not be wrapped again.
///
/// Builds `INDEX(BUFFERIZE(INDEX(buf0)))`, checks that the input really holds
/// one `BUFFERIZE`, and then checks that the buffer-limit rewrite leaves the
/// `BUFFERIZE` count unchanged.
#[test]
fn test_no_double_materialization() {
    // Only buffer 0 is needed, so create it directly instead of building and
    // discarding a whole multi-buffer ADD chain.
    let buffer = create_test_buffer(40, DType::Float32, 0, DeviceSpec::Cpu);
    let mut ctx = IndexingContext::new();
    let ranges = vec![ctx.new_range(&SInt::Const(10), AxisType::Loop)];

    let indexed = UOp::index().buffer(buffer).indices(ranges.clone()).call().unwrap();
    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let materialized = UOp::bufferize(indexed, ranges.clone(), opts);
    let indexed_materialized = UOp::index().buffer(materialized).indices(ranges).call().unwrap();

    let before_count = count_bufferizes(&indexed_materialized);
    // Guard against a vacuous pass: the input must actually be materialized.
    assert_eq!(before_count, 1, "Test graph should contain exactly one BUFFERIZE");

    let matcher = buffer_limit_patterns(31);
    let result = graph_rewrite(&matcher, indexed_materialized, &mut ());
    let after_count = count_bufferizes(&result);
    assert_eq!(before_count, after_count, "Should not double-materialize already-materialized operations");
}
/// The full rangeify pipeline should run to completion on a 10-input CPU graph.
#[test]
fn test_integration_with_rangeify_pipeline() {
    let (_, computation) = create_multi_buffer_computation(10, DeviceSpec::Cpu);
    let outcome = crate::rangeify::rangeify(computation.clone(), None);
    assert!(
        outcome.is_ok(),
        "Rangeify pipeline should complete successfully with buffer limit enforcement"
    );
}
/// A 20-input ADD chain under a 10-buffer limit must gain BUFFERIZE nodes.
#[test]
fn test_multiple_binary_ops() {
    let device = DeviceSpec::Cpu;
    let mut ctx = IndexingContext::new();
    let ranges = vec![ctx.new_range(&SInt::Const(10), AxisType::Loop)];

    let buffers: Vec<Arc<UOp>> = (0..20)
        .map(|i| create_test_buffer(40, DType::Float32, i, device.clone()))
        .collect();
    let load = |buf: &Arc<UOp>| UOp::index().buffer(buf.clone()).indices(ranges.clone()).call().unwrap();

    // ((b0[r] + b1[r]) + b2[r]) + ... + b19[r]
    let expr = buffers[1..]
        .iter()
        .fold(load(&buffers[0]), |acc, buf| acc.try_add(&load(buf)).expect("Failed to create ADD"));

    let before = count_bufferizes(&expr);
    let rewritten = graph_rewrite(&buffer_limit_patterns(10), expr, &mut ());
    assert!(
        count_bufferizes(&rewritten) > before,
        "Should materialize intermediate results to stay within buffer limit"
    );
}
/// A WHERE whose condition, true branch and false branch together read 13
/// buffers must gain BUFFERIZE nodes under a 10-buffer limit.
#[test]
fn test_ternary_op_materialization() {
    let device = DeviceSpec::Cpu;
    let mut ctx = IndexingContext::new();
    let ranges = vec![ctx.new_range(&SInt::Const(10), AxisType::Loop)];

    let buffers: Vec<Arc<UOp>> = (0..15)
        .map(|i| create_test_buffer(40, DType::Float32, i, device.clone()))
        .collect();
    let load = |i: usize| UOp::index().buffer(buffers[i].clone()).indices(ranges.clone()).call().unwrap();

    // cond = (b1 < b0) & (b3 < b2) & (b5 < b4) & (b7 < b6) & (b9 < b8)
    let first = load(0);
    let second = load(1);
    let mut cond = second.try_cmplt(&first).unwrap();
    for i in (2..10).step_by(2) {
        let lhs = load(i);
        let rhs = load(i + 1);
        cond = cond.try_and_op(&rhs.try_cmplt(&lhs).unwrap()).unwrap();
    }

    let true_val = load(10);
    let false_val = load(11).try_add(&load(12)).expect("Failed to create ADD");
    let where_op = UOp::try_where(cond, true_val, false_val).unwrap();

    let before = count_bufferizes(&where_op);
    let rewritten = graph_rewrite(&buffer_limit_patterns(10), where_op, &mut ());
    assert!(
        count_bufferizes(&rewritten) > before,
        "Should materialize intermediate results in ternary operations"
    );
}
/// Combined classification check: ADD and WHERE are elementwise, CONST is not.
#[test]
fn test_is_elementwise() {
    let sum = UOp::native_const(1.0f32)
        .try_add(&UOp::native_const(2.0f32))
        .unwrap();
    assert!(is_elementwise(&sum), "Binary ADD should be elementwise");

    let select = UOp::try_where(
        UOp::native_const(true),
        UOp::native_const(1.0f32),
        UOp::native_const(2.0f32),
    )
    .unwrap();
    assert!(is_elementwise(&select), "Ternary WHERE should be elementwise");

    let constant = UOp::native_const(1.0f32);
    assert!(!is_elementwise(&constant), "CONST should not be elementwise");
}