#![allow(clippy::identity_op)]
use std::sync::Arc;
use morok_device::DeviceSpec;
use morok_dtype::DType;
use morok_ir::{Op, SInt, UOp};
use smallvec::SmallVec;
use crate::rangeify::kernel::{SplitReduceOpConfig, collect_range_ids, split_reduceop};
/// The default configuration is enabled and carries the expected split parameters.
#[test]
fn test_config_default() {
    let SplitReduceOpConfig { split_threshold, output_size_bits, max_divisor, min_divisor, enabled, .. } =
        SplitReduceOpConfig::default();

    assert_eq!(split_threshold, 32768);
    assert_eq!(output_size_bits, 22);
    assert_eq!(max_divisor, 256);
    assert_eq!(min_divisor, 8);
    assert!(enabled);
}
/// With the default 22 output bits, the maximum output size is 2^22 elements.
#[test]
fn test_config_max_output_size() {
    let defaults = SplitReduceOpConfig::default();
    // 4_194_304 == 2^22
    assert_eq!(defaults.max_output_size(), 4_194_304);
}
/// Struct-update syntax overrides individual fields, and `max_output_size`
/// follows the overridden `output_size_bits`.
#[test]
fn test_config_custom() {
    let custom = SplitReduceOpConfig { split_threshold: 65536, output_size_bits: 20, ..Default::default() };

    assert_eq!(custom.split_threshold, 65536);
    assert_eq!(custom.output_size_bits, 20);
    // 1_048_576 == 2^20
    assert_eq!(custom.max_output_size(), 1_048_576);
}
/// A plain constant contains no range nodes, so nothing is collected.
#[test]
fn test_collect_range_ids_empty() {
    let constant = UOp::native_const(1.0f32);
    assert!(collect_range_ids(&constant).is_empty());
}
/// A lone range node yields exactly its own id.
#[test]
fn test_collect_range_ids_single() {
    // Second argument is the range id (see the multi-range tests below).
    let only_range = UOp::range_const(10, 0);
    assert_eq!(collect_range_ids(&only_range), vec![0]);
}
/// Every range referenced anywhere in a compound expression is collected.
#[test]
fn test_collect_range_ids_multiple() {
    let first = UOp::range_const(10, 0);
    let second = UOp::range_const(5, 1);
    let third = UOp::range_const(3, 2);

    // (first + second) * third
    let expr = first.try_add(&second).unwrap().try_mul(&third).unwrap();

    assert_eq!(collect_range_ids(&expr), vec![0, 1, 2]);
}
/// Collected ids are expected in ascending order even when the ranges appear
/// out of order in the expression.
#[test]
fn test_collect_range_ids_unsorted() {
    let id2 = UOp::range_const(3, 2);
    let id0 = UOp::range_const(10, 0);
    let id1 = UOp::range_const(5, 1);

    // Operands are combined as id 2, then id 0, then id 1.
    let partial = id2.try_add(&id0).unwrap();
    let expr = partial.try_add(&id1).unwrap();

    assert_eq!(collect_range_ids(&expr), vec![0, 1, 2]);
}
/// Creates a `Float32` CPU buffer holding `shape.iter().product()` elements.
///
/// A one-dimensional `shape` returns the buffer as-is. Any other rank returns
/// the buffer reshaped to `shape`.
fn create_test_tensor(shape: &[usize]) -> Arc<UOp> {
    let element_count: usize = shape.iter().product();
    let flat = UOp::new_buffer(DeviceSpec::Cpu, element_count, DType::Float32);

    // The buffer is already 1-D, so only other ranks need an explicit reshape.
    if let [_] = shape {
        return flat;
    }

    let dims: SmallVec<[SInt; 4]> = shape.iter().copied().map(SInt::Const).collect();
    flat.try_reshape(&dims).unwrap()
}
/// With `enabled: false`, a reduction large enough to split is left untouched.
#[test]
fn test_split_reduceop_disabled() {
    let disabled = SplitReduceOpConfig { enabled: false, ..Default::default() };
    let tensor = create_test_tensor(&[100000]);
    let reduce = tensor.try_reduce_axis(morok_ir::ReduceOp::Add, vec![0]).unwrap();

    let outcome = split_reduceop(&reduce, &disabled);
    assert!(outcome.is_none());
}
/// A 1000-element reduction is below the default split threshold, so it is not split.
#[test]
fn test_split_reduceop_below_threshold() {
    let tensor = create_test_tensor(&[1000]);
    let reduce = tensor.try_reduce_axis(morok_ir::ReduceOp::Add, vec![0]).unwrap();

    let outcome = split_reduceop(&reduce, &SplitReduceOpConfig::default());
    assert!(outcome.is_none());
}
/// A 100000-element sum exceeds the default split threshold and must be split.
/// The rewritten graph keeps the rank of the original reduction's output.
#[test]
fn test_split_reduceop_basic_split() {
    let config = SplitReduceOpConfig::default();
    let tensor = create_test_tensor(&[100000]);
    let reduce = tensor.try_reduce_axis(morok_ir::ReduceOp::Add, vec![0]).unwrap();

    // `expect` both asserts the split happened and unwraps it, with the same failure message.
    let transformed = split_reduceop(&reduce, &config).expect("Should split large reduction");

    let original_shape = reduce.shape().unwrap().unwrap();
    let result_shape = transformed.shape().unwrap().unwrap();
    assert_eq!(result_shape.len(), original_shape.len());
}
/// For each supported reduce operation, the split graph must still contain a
/// `ReduceAxis` node using that same operation.
#[test]
fn test_split_reduceop_preserves_reduce_op() {
    let config = SplitReduceOpConfig::default();
    for reduce_op in
        [morok_ir::ReduceOp::Add, morok_ir::ReduceOp::Mul, morok_ir::ReduceOp::Max, morok_ir::ReduceOp::Min]
    {
        let tensor = create_test_tensor(&[100000]);
        let reduce = tensor.try_reduce_axis(reduce_op, vec![0]).unwrap();

        // Fail with the same message as before, without a separate `is_some` check.
        let transformed = split_reduceop(&reduce, &config)
            .unwrap_or_else(|| panic!("Should split large reduction for {:?}", reduce_op));

        let has_reduce_op = transformed
            .toposort()
            .iter()
            .any(|node| matches!(node.op(), Op::ReduceAxis { reduce_op: op, .. } if *op == reduce_op));
        assert!(has_reduce_op, "Reduce op {:?} should be preserved", reduce_op);
    }
}
/// The partial results of a split reduction are materialized via a `Contiguous` node.
#[test]
fn test_split_reduceop_has_contiguous() {
    let config = SplitReduceOpConfig::default();
    let tensor = create_test_tensor(&[100000]);
    let reduce = tensor.try_reduce_axis(morok_ir::ReduceOp::Add, vec![0]).unwrap();

    let transformed = split_reduceop(&reduce, &config).expect("Should split large reduction");

    let has_contiguous = transformed.toposort().iter().any(|node| matches!(node.op(), Op::Contiguous { .. }));
    assert!(has_contiguous, "Should have CONTIGUOUS for intermediate materialization");
}
/// Reducing a 1000x1000 tensor over axis 1 folds only 1000 elements into each
/// output, which is below the split threshold.
#[test]
fn test_split_reduceop_multidim_below_threshold() {
    let tensor = create_test_tensor(&[1000, 1000]);
    let reduce = tensor.try_reduce_axis(morok_ir::ReduceOp::Add, vec![1]).unwrap();

    let outcome = split_reduceop(&reduce, &SplitReduceOpConfig::default());
    assert!(outcome.is_none(), "Should NOT split - ratio too low");
}
/// Reducing a 1000x100000 tensor over axis 1 folds 100000 elements into each
/// output. That is above the threshold, so the reduction is split and the output
/// rank is preserved.
#[test]
fn test_split_reduceop_multidim_above_threshold() {
    let config = SplitReduceOpConfig::default();
    let tensor = create_test_tensor(&[1000, 100000]);
    let reduce = tensor.try_reduce_axis(morok_ir::ReduceOp::Add, vec![1]).unwrap();

    let transformed = split_reduceop(&reduce, &config).expect("Should split large multidim reduction");

    let original_shape = reduce.shape().unwrap().unwrap();
    let result_shape = transformed.shape().unwrap().unwrap();
    assert_eq!(result_shape.len(), original_shape.len());
}
/// A reduction over a broadcast (expanded) axis must not be split.
///
/// The broadcast axis is sized so that the split threshold alone would NOT reject
/// the reduction: 10x100000x10 reduced over axis 1 is an input/output ratio of 100000.
/// A control reduction over a plain tensor of the same shape is asserted to split.
/// The `None` for the expanded input can therefore only come from broadcast detection.
/// (The previous [100, 500, 1000] shape had a ratio of 500, so it was rejected by
/// the threshold regardless of broadcast handling.)
#[test]
fn test_split_with_expand_detects_broadcast() {
    let config = SplitReduceOpConfig::default();

    // Control: same logical shape without broadcasting, large enough to be split.
    let plain = create_test_tensor(&[10, 100000, 10]);
    let plain_reduce = plain.try_reduce_axis(morok_ir::ReduceOp::Add, vec![1]).unwrap();
    assert!(
        split_reduceop(&plain_reduce, &config).is_some(),
        "Control: a non-broadcast axis of this size should split"
    );

    // Axis 1 has extent 1 in the source and is expanded to 100000.
    let source = create_test_tensor(&[10, 1, 10]);
    let expand_shape =
        UOp::vectorize(vec![UOp::index_const(10), UOp::index_const(100000), UOp::index_const(10)].into());
    let expanded = UOp::new(Op::Expand { src: source, new_shape: expand_shape }, DType::Float32);
    let reduce = expanded.try_reduce_axis(morok_ir::ReduceOp::Add, vec![1]).unwrap();

    let result = split_reduceop(&reduce, &config);
    assert!(result.is_none(), "Should NOT split - axis 1 is broadcast (expanded)");
}
/// Expand followed by a flattening reshape: the 50000-element flattened axis
/// mixes a real dimension (50) with a broadcast one (1000), and the split must
/// still be found through these nested movement ops.
#[test]
fn test_split_with_nested_movement_ops() {
    let config = SplitReduceOpConfig::default();

    // [50, 1] view of a 50-element buffer, broadcast to [50, 1000], then flattened.
    let source = create_test_tensor(&[50, 1]);
    let expand_shape = UOp::vectorize(vec![UOp::index_const(50), UOp::index_const(1000)].into());
    let expanded = UOp::new(Op::Expand { src: source, new_shape: expand_shape }, DType::Float32);
    let flattened = expanded.try_reshape(&std::iter::once(SInt::Const(50000)).collect()).unwrap();
    let reduce = flattened.try_reduce_axis(morok_ir::ReduceOp::Add, vec![0]).unwrap();

    let transformed = split_reduceop(&reduce, &config)
        .expect("Should split - movement patterns should allow split on nested operations");

    let has_contiguous = transformed.toposort().iter().any(|node| matches!(node.op(), Op::Contiguous { .. }));
    assert!(has_contiguous, "Split result should have CONTIGUOUS");
}
/// A broadcast on a non-reduced axis (axis 1) must not prevent splitting the
/// reduced axis (axis 2), which is not expanded.
#[test]
fn test_split_skips_expanded_dimensions() {
    let config = SplitReduceOpConfig::default();

    // [100, 1, 100000] view, with axis 1 broadcast to 50.
    let source = create_test_tensor(&[100, 1, 100000]);
    let expand_shape =
        UOp::vectorize(vec![UOp::index_const(100), UOp::index_const(50), UOp::index_const(100000)].into());
    let expanded = UOp::new(Op::Expand { src: source, new_shape: expand_shape }, DType::Float32);
    let reduce = expanded.try_reduce_axis(morok_ir::ReduceOp::Add, vec![2]).unwrap();

    let transformed = split_reduceop(&reduce, &config)
        .expect("Should split on axis 2 - it's not expanded (only axis 1 is expanded)");

    let has_contiguous = transformed.toposort().iter().any(|node| matches!(node.op(), Op::Contiguous { .. }));
    assert!(has_contiguous, "Should have split successfully");
}