use std::sync::Arc;
use morok_ir::{AxisId, AxisType, ConstValue, Op, ReduceOp, UOp};
use crate::optimizer::error::OptError;
use crate::optimizer::{OptOps, Renderer, Scheduler};
/// A scheduler built from a bare constant has no applied opts, leaves
/// `dont_use_locals` unset and reports a shape length of zero.
#[test]
fn test_scheduler_new() {
    let sched = Scheduler::new(UOp::native_const(1.0f32), Renderer::cpu());

    assert_eq!(sched.applied_opts.len(), 0);
    assert!(!sched.dont_use_locals);
    assert_eq!(sched.shape_len(), 0);
}
/// `rngs()` is expected to return the four ranges in the order
/// Loop, Global, Local, Reduce, even though the sink lists them as
/// Global, Local, Reduce, Loop.
#[test]
fn test_scheduler_rngs_sorting() {
    let end_16 = UOp::index_const(16);
    let end_8 = UOp::index_const(8);
    let end_32 = UOp::index_const(32);
    let end_4 = UOp::index_const(4);
    let r_global = UOp::range_axis(end_16, AxisId::Renumbered(0), AxisType::Global);
    let r_local = UOp::range_axis(end_8, AxisId::Renumbered(1), AxisType::Local);
    let r_reduce = UOp::range_axis(end_32, AxisId::Renumbered(2), AxisType::Reduce);
    let r_loop = UOp::range_axis(end_4, AxisId::Renumbered(3), AxisType::Loop);
    let compute = UOp::native_const(1.0f32);
    let sink = UOp::sink(vec![compute, r_global, r_local, r_reduce, r_loop]);
    let ren = Renderer::cpu();
    let scheduler = Scheduler::new(sink, ren);

    let rngs = scheduler.rngs();
    assert_eq!(rngs.len(), 4);

    let expected = [AxisType::Loop, AxisType::Global, AxisType::Local, AxisType::Reduce];
    for (idx, want) in expected.into_iter().enumerate() {
        // A non-Range entry must fail the test; an `if let` here would silently
        // skip the assertion and let the test pass vacuously.
        let Op::Range { axis_type, .. } = rngs[idx].op() else {
            panic!("rngs[{idx}] is not a Range operation");
        };
        assert_eq!(*axis_type, want, "unexpected axis type at rngs[{idx}]");
    }
}
/// With ranges carrying axis ids 5, 2 and 10, `maxarg()` is expected to be 10.
#[test]
fn test_scheduler_maxarg() {
    let loop_rng = UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(5), AxisType::Loop);
    let global_rng = UOp::range_axis(UOp::index_const(20), AxisId::Renumbered(2), AxisType::Global);
    let reduce_rng = UOp::range_axis(UOp::index_const(30), AxisId::Renumbered(10), AxisType::Reduce);
    let body = UOp::native_const(1.0f32);
    let root = UOp::sink(vec![body, loop_rng, global_rng, reduce_rng]);
    let renderer = Renderer::cpu();

    let sched = Scheduler::new(root, renderer);

    assert_eq!(sched.maxarg(), 10);
}
/// Exercises the read-only helpers on a kernel with one Global, one Local and
/// one Reduce range feeding an Add reduction: reduce-op lookup, output shape,
/// upcast size, group-for-reduce count and buffer list.
#[test]
fn test_scheduler_helper_properties() {
    let extent_8 = UOp::index_const(8);
    let extent_16 = UOp::index_const(16);
    let extent_32 = UOp::index_const(32);
    let global_rng = UOp::range_axis(extent_16.clone(), AxisId::Renumbered(0), AxisType::Global);
    let local_rng = UOp::range_axis(extent_8.clone(), AxisId::Renumbered(1), AxisType::Local);
    let reduce_rng = UOp::range_axis(extent_32.clone(), AxisId::Renumbered(2), AxisType::Reduce);
    let scalar = UOp::native_const(1.0f32);
    let reduction = scalar.clone().reduce(vec![reduce_rng.clone()].into(), ReduceOp::Add);
    let root = UOp::sink(vec![reduction, global_rng, local_rng, reduce_rng]);
    let renderer = Renderer::cpu();
    let sched = Scheduler::new(root, renderer);

    // Exactly one reduction is present.
    assert!(sched.reduceop().is_some());
    assert_eq!(sched.reduceops().len(), 1);

    // The output shape is expected to hold the sizes 16 and 8, in that order.
    let out_shape = sched.output_shape();
    assert_eq!(out_shape.len(), 2);
    assert_eq!(out_shape[0], 16);
    assert_eq!(out_shape[1], 8);

    assert_eq!(sched.upcast_size(), 1);
    assert_eq!(sched.group_for_reduces(), 0);
    assert_eq!(sched.bufs().len(), 0);
}
/// Two Upcast ranges of sizes 4 and 8 are expected to give an `upcast_size()` of 32.
#[test]
fn test_scheduler_upcast_size() {
    let extent_4 = UOp::index_const(4);
    let extent_8 = UOp::index_const(8);
    let upcast_a = UOp::range_axis(extent_4, AxisId::Renumbered(0), AxisType::Upcast);
    let upcast_b = UOp::range_axis(extent_8, AxisId::Renumbered(1), AxisType::Upcast);
    let body = UOp::native_const(1.0f32);
    let root = UOp::sink(vec![body, upcast_a, upcast_b]);
    let renderer = Renderer::cpu();

    let sched = Scheduler::new(root, renderer);

    assert_eq!(sched.upcast_size(), 32);
}
/// One GroupReduce range next to one Reduce range is expected to give a
/// `group_for_reduces()` of 1.
#[test]
fn test_scheduler_group_for_reduces() {
    let extent_16 = UOp::index_const(16);
    let extent_32 = UOp::index_const(32);
    let group_rng = UOp::range_axis(extent_16, AxisId::Renumbered(0), AxisType::GroupReduce);
    let reduce_rng = UOp::range_axis(extent_32, AxisId::Renumbered(1), AxisType::Reduce);
    let body = UOp::native_const(1.0f32);
    let root = UOp::sink(vec![body, group_rng, reduce_rng]);
    let renderer = Renderer::cpu();

    let sched = Scheduler::new(root, renderer);

    assert_eq!(sched.group_for_reduces(), 1);
}
/// Checks `axes_of` for single and combined axis-type filters, and that
/// `ranges_of(&[Reduce])` returns exactly one range whose axis type is Reduce.
#[test]
fn test_scheduler_axes_of() {
    let end_16 = UOp::index_const(16);
    let end_8 = UOp::index_const(8);
    let end_32 = UOp::index_const(32);
    let r_global = UOp::range_axis(end_16, AxisId::Renumbered(0), AxisType::Global);
    let r_local = UOp::range_axis(end_8, AxisId::Renumbered(1), AxisType::Local);
    let r_reduce = UOp::range_axis(end_32, AxisId::Renumbered(2), AxisType::Reduce);
    let compute = UOp::native_const(1.0f32);
    let sink = UOp::sink(vec![compute, r_global, r_local, r_reduce]);
    let ren = Renderer::cpu();
    let scheduler = Scheduler::new(sink, ren);

    let global_axes = scheduler.axes_of(&[AxisType::Global]);
    assert_eq!(global_axes, vec![0]);

    let reduce_axes = scheduler.axes_of(&[AxisType::Reduce]);
    assert_eq!(reduce_axes, vec![2]);

    let parallel_axes = scheduler.axes_of(&[AxisType::Global, AxisType::Local]);
    assert_eq!(parallel_axes, vec![0, 1]);

    let reduce_rngs = scheduler.ranges_of(&[AxisType::Reduce]);
    assert_eq!(reduce_rngs.len(), 1);
    // Panic on a non-Range node rather than skipping the check, so the
    // assertion below can never be bypassed silently.
    let Op::Range { axis_type, .. } = reduce_rngs[0].op() else {
        panic!("ranges_of returned a node that is not a Range operation");
    };
    assert_eq!(*axis_type, AxisType::Reduce);
}
/// With Global(16), Loop(32), Reduce(16) and a size-1 Global range,
/// `upcastable_dims()` is expected to report exactly two dimensions: 0 and 1.
#[test]
fn test_scheduler_upcastable_dims() {
    let extent_1 = UOp::index_const(1);
    let extent_16 = UOp::index_const(16);
    let extent_32 = UOp::index_const(32);
    let global_rng = UOp::range_axis(extent_16.clone(), AxisId::Renumbered(0), AxisType::Global);
    let loop_rng = UOp::range_axis(extent_32, AxisId::Renumbered(1), AxisType::Loop);
    let reduce_rng = UOp::range_axis(extent_16.clone(), AxisId::Renumbered(2), AxisType::Reduce);
    let unit_rng = UOp::range_axis(extent_1, AxisId::Renumbered(3), AxisType::Global);
    let body = UOp::native_const(1.0f32);
    let root = UOp::sink(vec![body, global_rng, loop_rng, reduce_rng, unit_rng]);
    let renderer = Renderer::cpu();
    let sched = Scheduler::new(root, renderer);

    let dims = sched.upcastable_dims();

    assert_eq!(dims.len(), 2);
    assert!(dims.contains(&0));
    assert!(dims.contains(&1));
}
/// With one Global range and three Reduce ranges (sizes 32, 16 and 1),
/// `unrollable_dims()` is expected to report exactly two dimensions: 1 and 2.
#[test]
fn test_scheduler_unrollable_dims() {
    let extent_16 = UOp::index_const(16);
    let extent_32 = UOp::index_const(32);
    let extent_1 = UOp::index_const(1);
    let global_rng = UOp::range_axis(extent_16.clone(), AxisId::Renumbered(0), AxisType::Global);
    let reduce_a = UOp::range_axis(extent_32, AxisId::Renumbered(1), AxisType::Reduce);
    let reduce_b = UOp::range_axis(extent_16, AxisId::Renumbered(2), AxisType::Reduce);
    let reduce_unit = UOp::range_axis(extent_1, AxisId::Renumbered(3), AxisType::Reduce);
    let scalar = UOp::native_const(1.0f32);
    // Only the two non-unit reduce ranges take part in the reduction itself.
    let reduction = scalar.reduce(vec![reduce_a.clone(), reduce_b.clone()].into(), ReduceOp::Add);
    let root = UOp::sink(vec![reduction, global_rng, reduce_a, reduce_b, reduce_unit]);
    let renderer = Renderer::cpu();
    let sched = Scheduler::new(root, renderer);

    let dims = sched.unrollable_dims();

    assert_eq!(dims.len(), 2);
    assert!(dims.contains(&1));
    assert!(dims.contains(&2));
}
/// Table-driven check of `real_axis`: each `(op, logical axis)` pair must
/// resolve to the listed value, and two out-of-range requests must fail.
#[test]
fn test_scheduler_real_axis() {
    let extent_16 = UOp::index_const(16);
    let extent_32 = UOp::index_const(32);
    let global_rng = UOp::range_axis(extent_16.clone(), AxisId::Renumbered(0), AxisType::Global);
    let loop_rng = UOp::range_axis(extent_16.clone(), AxisId::Renumbered(1), AxisType::Loop);
    let reduce_a = UOp::range_axis(extent_32.clone(), AxisId::Renumbered(2), AxisType::Reduce);
    let reduce_b = UOp::range_axis(extent_16, AxisId::Renumbered(3), AxisType::Reduce);
    let scalar = UOp::native_const(1.0f32);
    let reduction = scalar.reduce(vec![reduce_a.clone(), reduce_b.clone()].into(), ReduceOp::Add);
    let root = UOp::sink(vec![reduction, global_rng, loop_rng, reduce_a, reduce_b]);
    let renderer = Renderer::cpu();
    let sched = Scheduler::new(root, renderer);

    // (operation, requested axis, expected resolved axis)
    let resolved = [
        (OptOps::UPCAST, Some(1), 1),
        (OptOps::LOCAL, Some(0), 0),
        (OptOps::UNROLL, Some(0), 2),
        (OptOps::UNROLL, Some(1), 3),
        (OptOps::GROUP, Some(0), 2),
        (OptOps::GROUP, Some(1), 3),
        // Axis-less operations are expected to resolve to -1.
        (OptOps::TC, None, -1),
        (OptOps::NOLOCALS, None, -1),
    ];
    for (op, axis, want) in resolved {
        assert_eq!(sched.real_axis(op, axis).unwrap(), want);
    }

    // Requests beyond the available axes are rejected.
    assert!(sched.real_axis(OptOps::UPCAST, Some(10)).is_err());
    assert!(sched.real_axis(OptOps::UNROLL, Some(5)).is_err());
}
/// Checks the textual shape helpers on a reduce kernel: `colored_shape`,
/// `shape_str`, `kernel_type` and the `Display` output.
#[test]
fn test_scheduler_colored_shape() {
    let extent_16 = UOp::index_const(16);
    let extent_8 = UOp::index_const(8);
    let extent_32 = UOp::index_const(32);
    let extent_4 = UOp::index_const(4);
    let global_rng = UOp::range_axis(extent_16, AxisId::Renumbered(0), AxisType::Global);
    let local_rng = UOp::range_axis(extent_8, AxisId::Renumbered(1), AxisType::Local);
    let reduce_rng = UOp::range_axis(extent_32, AxisId::Renumbered(2), AxisType::Reduce);
    let upcast_rng = UOp::range_axis(extent_4, AxisId::Renumbered(3), AxisType::Upcast);
    let scalar = UOp::native_const(1.0f32);
    let reduction = scalar.reduce(vec![reduce_rng.clone()].into(), ReduceOp::Add);
    let root = UOp::sink(vec![reduction, global_rng, local_rng, reduce_rng, upcast_rng]);
    let renderer = Renderer::cpu();
    let sched = Scheduler::new(root, renderer);

    let joined = sched.colored_shape();
    assert_eq!(joined, "g16l8u4R32");

    let pieces = sched.shape_str();
    assert_eq!(pieces, vec!["g16", "l8", "u4", "R32"]);

    assert_eq!(sched.kernel_type(), "r");

    // Display is expected to be the kernel type, an underscore, then the shape.
    let rendered = sched.to_string();
    assert_eq!(rendered, "r_g16l8u4R32");
}
/// A kernel with two Global ranges and no reduction is expected to have
/// kernel type "E" and display as "E_g256g256".
#[test]
fn test_scheduler_display_elementwise() {
    let extent_256 = UOp::index_const(256);
    let global_a = UOp::range_axis(extent_256.clone(), AxisId::Renumbered(0), AxisType::Global);
    let global_b = UOp::range_axis(extent_256, AxisId::Renumbered(1), AxisType::Global);
    let body = UOp::native_const(1.0f32);
    let root = UOp::sink(vec![body, global_a, global_b]);
    let renderer = Renderer::cpu();
    let sched = Scheduler::new(root, renderer);

    assert_eq!(sched.kernel_type(), "E");

    let rendered = sched.to_string();
    assert_eq!(rendered, "E_g256g256");
}
/// Shape and display strings for a kernel with six axis types:
/// Loop, Global, Local, Reduce, Upcast and Unroll.
#[test]
fn test_scheduler_display_complex() {
    let extent_2 = UOp::index_const(2);
    let extent_4 = UOp::index_const(4);
    let extent_8 = UOp::index_const(8);
    let extent_16 = UOp::index_const(16);
    let extent_32 = UOp::index_const(32);
    let loop_rng = UOp::range_axis(extent_2, AxisId::Renumbered(0), AxisType::Loop);
    let global_rng = UOp::range_axis(extent_32.clone(), AxisId::Renumbered(1), AxisType::Global);
    let local_rng = UOp::range_axis(extent_16, AxisId::Renumbered(2), AxisType::Local);
    let reduce_rng = UOp::range_axis(extent_32, AxisId::Renumbered(3), AxisType::Reduce);
    let upcast_rng = UOp::range_axis(extent_4, AxisId::Renumbered(4), AxisType::Upcast);
    let unroll_rng = UOp::range_axis(extent_8, AxisId::Renumbered(5), AxisType::Unroll);
    let scalar = UOp::native_const(1.0f32);
    let reduction = scalar.reduce(vec![reduce_rng.clone(), unroll_rng.clone()].into(), ReduceOp::Add);
    let root =
        UOp::sink(vec![reduction, loop_rng, global_rng, local_rng, reduce_rng, upcast_rng, unroll_rng]);
    let renderer = Renderer::cpu();
    let sched = Scheduler::new(root, renderer);

    let joined = sched.colored_shape();
    assert_eq!(joined, "L2g32l16u4R32r8");

    let rendered = sched.to_string();
    assert_eq!(rendered, "r_L2g32l16u4R32r8");
}
/// Splitting a Global range of 16 by 4 into an Upcast axis: the replaced range
/// keeps axis id 0 and the Global type with size 4, the new range gets axis
/// id 1 and the Upcast type with size 4, and shape length / maxarg each grow by one.
#[test]
fn test_shift_to_basic_split() {
    let extent = UOp::index_const(16);
    let global_rng = UOp::range_axis(extent, AxisId::Renumbered(0), AxisType::Global);
    let body = UOp::native_const(1.0f32);
    let root = UOp::sink(vec![body, global_rng.clone()]);
    let renderer = Renderer::cpu();
    let mut sched = Scheduler::new(root, renderer);
    assert_eq!(sched.shape_len(), 1);
    assert_eq!(sched.maxarg(), 0);

    let outcome = sched.shift_to(global_rng.clone(), 4, AxisType::Upcast, false, None);
    assert!(outcome.is_ok());
    let (shrunk, fresh) = outcome.unwrap();

    // The range that replaced the original one.
    let Op::Range { end, axis_id, axis_type, .. } = shrunk.op() else {
        panic!("Expected Range operation");
    };
    assert_eq!(axis_id, &AxisId::Renumbered(0));
    assert_eq!(axis_type, &AxisType::Global);
    let Op::Const(cv) = end.op() else {
        panic!("Expected constant size");
    };
    let ConstValue::Int(sz) = cv.0 else {
        panic!("Expected constant size");
    };
    assert_eq!(sz, 4);

    // The newly created range.
    let Op::Range { end, axis_id, axis_type, .. } = fresh.op() else {
        panic!("Expected Range operation");
    };
    assert_eq!(axis_id, &AxisId::Renumbered(1));
    assert_eq!(axis_type, &AxisType::Upcast);
    let Op::Const(cv) = end.op() else {
        panic!("Expected constant size");
    };
    let ConstValue::Int(sz) = cv.0 else {
        panic!("Expected constant size");
    };
    assert_eq!(sz, 4);

    assert_eq!(sched.shape_len(), 2);
    assert_eq!(sched.maxarg(), 1);
}
/// `shift_to` with the fourth argument set to `true` on a Global range of 16,
/// split by 4 into a Local axis: both resulting ranges are expected to have
/// constant size 4 and the new one must be Local.
#[test]
fn test_shift_to_top_order() {
    let extent = UOp::index_const(16);
    let global_rng = UOp::range_axis(extent, AxisId::Renumbered(0), AxisType::Global);
    let body = UOp::native_const(1.0f32);
    let root = UOp::sink(vec![body, global_rng.clone()]);
    let renderer = Renderer::cpu();
    let mut sched = Scheduler::new(root, renderer);

    let outcome = sched.shift_to(global_rng.clone(), 4, AxisType::Local, true, None);
    assert!(outcome.is_ok());
    let (shrunk, fresh) = outcome.unwrap();

    // Any mismatch along the Range -> Const -> Int chain is reported the same way.
    let Op::Range { end, .. } = shrunk.op() else {
        panic!("Expected constant size");
    };
    let Op::Const(cv) = end.op() else {
        panic!("Expected constant size");
    };
    let ConstValue::Int(sz) = cv.0 else {
        panic!("Expected constant size");
    };
    assert_eq!(sz, 4);

    let Op::Range { end, axis_type, .. } = fresh.op() else {
        panic!("Expected Range operation");
    };
    assert_eq!(axis_type, &AxisType::Local);
    let Op::Const(cv) = end.op() else {
        panic!("Expected constant size");
    };
    let ConstValue::Int(sz) = cv.0 else {
        panic!("Expected constant size");
    };
    assert_eq!(sz, 4);

    assert_eq!(sched.shape_len(), 2);
}
/// Splitting a range of size 15 by 4 must fail with `OptError::DivisionError`.
#[test]
fn test_shift_to_division_error() {
    let extent = UOp::index_const(15);
    let global_rng = UOp::range_axis(extent, AxisId::Renumbered(0), AxisType::Global);
    let body = UOp::native_const(1.0f32);
    let root = UOp::sink(vec![body, global_rng.clone()]);
    let renderer = Renderer::cpu();
    let mut sched = Scheduler::new(root, renderer);

    let outcome = sched.shift_to(global_rng.clone(), 4, AxisType::Upcast, false, None);

    assert!(matches!(outcome, Err(OptError::DivisionError { .. })));
}
/// After `shift_to`, the AST must contain exactly one constant-sized Range
/// with axis id 0, and its size must be 4 (the original size-16 range is gone).
#[test]
fn test_shift_to_substitution_in_ast() {
    let extent = UOp::index_const(16);
    let global_rng = UOp::range_axis(extent.clone(), AxisId::Renumbered(0), AxisType::Global);
    let factor = UOp::index_const(2);
    // The computation uses the range, so the substitution has to reach it.
    let body = global_rng.try_mul(&factor).unwrap();
    let root = UOp::sink(vec![body.clone(), global_rng.clone()]);
    let renderer = Renderer::cpu();
    let mut sched = Scheduler::new(root, renderer);

    let outcome = sched.shift_to(global_rng.clone(), 4, AxisType::Upcast, false, None);
    assert!(outcome.is_ok());

    let ranges_now = sched.rngs();
    assert_eq!(ranges_now.len(), 2);

    // Collect the constant sizes of every Range node carrying axis id 0.
    let nodes = sched.ast().toposort();
    let mut axis0_sizes = Vec::new();
    for node in nodes.iter() {
        if let Op::Range { end, axis_id, .. } = node.op()
            && *axis_id == AxisId::Renumbered(0)
            && let Op::Const(cv) = end.op()
            && let ConstValue::Int(sz) = cv.0
        {
            axis0_sizes.push(sz);
        }
    }

    assert_eq!(axis0_sizes.len(), 1);
    assert_eq!(axis0_sizes[0], 4);
}
/// Queries `rngs`, `maxarg` and `shape_len` before a `shift_to`, then checks
/// that the same queries reflect the split afterwards rather than stale values.
#[test]
fn test_shift_to_cache_invalidation() {
    let extent = UOp::index_const(16);
    let global_rng = UOp::range_axis(extent, AxisId::Renumbered(0), AxisType::Global);
    let body = UOp::native_const(1.0f32);
    let root = UOp::sink(vec![body, global_rng.clone()]);
    let renderer = Renderer::cpu();
    let mut sched = Scheduler::new(root, renderer);

    // Results are deliberately unused; the calls only need to happen before the split.
    let _warm_rngs = sched.rngs();
    let _warm_maxarg = sched.maxarg();
    let _warm_shape_len = sched.shape_len();

    let outcome = sched.shift_to(global_rng.clone(), 4, AxisType::Upcast, false, None);
    assert!(outcome.is_ok());

    let fresh_rngs = sched.rngs();
    assert_eq!(fresh_rngs.len(), 2);
    let fresh_maxarg = sched.maxarg();
    assert_eq!(fresh_maxarg, 1);
    let fresh_shape_len = sched.shape_len();
    assert_eq!(fresh_shape_len, 2);
}
/// When a caller-supplied range (axis id 99) is passed to `shift_to`, the
/// returned new range is expected to carry that axis id.
#[test]
fn test_shift_to_with_custom_range() {
    let extent_16 = UOp::index_const(16);
    let global_rng = UOp::range_axis(extent_16.clone(), AxisId::Renumbered(0), AxisType::Global);
    let extent_4 = UOp::index_const(4);
    let supplied = UOp::range_axis(extent_4, AxisId::Renumbered(99), AxisType::Upcast);
    let body = UOp::native_const(1.0f32);
    let root = UOp::sink(vec![body, global_rng.clone()]);
    let renderer = Renderer::cpu();
    let mut sched = Scheduler::new(root, renderer);

    let outcome =
        sched.shift_to(global_rng.clone(), 4, AxisType::Upcast, false, Some(supplied.clone()));
    assert!(outcome.is_ok());
    let (_shrunk, fresh) = outcome.unwrap();

    let Op::Range { axis_id, .. } = fresh.op() else {
        panic!("Expected Range operation");
    };
    assert_eq!(axis_id, &AxisId::Renumbered(99));
}
/// Two consecutive splits: first the size-64 Global range by 4 (Upcast), then
/// the replaced range by 2 (Local). Shape length and maxarg grow with each split.
#[test]
fn test_shift_to_multiple_splits() {
    let extent = UOp::index_const(64);
    let global_rng = UOp::range_axis(extent, AxisId::Renumbered(0), AxisType::Global);
    let body = UOp::native_const(1.0f32);
    let root = UOp::sink(vec![body, global_rng.clone()]);
    let renderer = Renderer::cpu();
    let mut sched = Scheduler::new(root, renderer);

    // First split.
    let first = sched.shift_to(global_rng.clone(), 4, AxisType::Upcast, false, None);
    assert!(first.is_ok());
    let (remaining_global, _upcast_part) = first.unwrap();
    assert_eq!(sched.shape_len(), 2);
    assert_eq!(sched.maxarg(), 1);

    // Second split, applied to the range returned by the first one.
    let second = sched.shift_to(remaining_global, 2, AxisType::Local, false, None);
    assert!(second.is_ok());
    assert_eq!(sched.shape_len(), 3);
    assert_eq!(sched.maxarg(), 2);

    let all_rngs = sched.rngs();
    assert_eq!(all_rngs.len(), 3);
}
use crate::optimizer::{Opt, OptArg, apply_opt};
/// Returns the axis type of a `Range` node.
///
/// # Panics
///
/// Panics with "Expected Range operation" if `uop` is not an `Op::Range`.
fn get_axis_type(uop: &UOp) -> AxisType {
    match uop.op() {
        Op::Range { axis_type, .. } => *axis_type,
        _ => panic!("Expected Range operation"),
    }
}
/// Applying `Opt::upcast(0, 4)` to a Global range of 16 succeeds, is recorded
/// in `applied_opts`, and leaves a Global range followed by an Upcast range.
#[test]
fn test_upcast_basic() {
    let extent = UOp::index_const(16);
    let global_rng = UOp::range_axis(extent, AxisId::Renumbered(0), AxisType::Global);
    let body = UOp::native_const(1.0f32);
    let root = UOp::sink(vec![body, global_rng]);
    let renderer = Renderer::cpu();
    let mut sched = Scheduler::new(root, renderer);
    let upcast = Opt::upcast(0, 4);

    let outcome = apply_opt(&mut sched, &upcast, true);
    assert!(outcome.is_ok());

    // The optimization is recorded exactly once.
    assert_eq!(sched.applied_opts.len(), 1);
    assert_eq!(sched.applied_opts[0], upcast);
    assert_eq!(sched.shape_len(), 2);

    let ranges = sched.rngs();
    assert_eq!(get_axis_type(&ranges[0]), AxisType::Global);
    assert_eq!(get_axis_type(&ranges[1]), AxisType::Upcast);
}
/// Upcasting a kernel whose only range is a Reduce range must fail with
/// `ValidationFailed` for the "UPCAST" op.
#[test]
fn test_upcast_invalid_axis_type() {
    let extent = UOp::index_const(32);
    let reduce_rng = UOp::range_axis(extent, AxisId::Renumbered(0), AxisType::Reduce);
    let body = UOp::native_const(1.0f32);
    let root = UOp::sink(vec![body, reduce_rng]);
    let renderer = Renderer::cpu();
    let mut sched = Scheduler::new(root, renderer);
    let upcast = Opt::upcast(0, 4);

    let outcome = apply_opt(&mut sched, &upcast, false);

    assert!(matches!(outcome, Err(OptError::ValidationFailed { op: "UPCAST", .. })));
}
/// An upcast by 32 on the CPU renderer must be rejected with
/// `DeviceLimitExceeded` for the "upcast" limit.
#[test]
fn test_upcast_device_limit() {
    let extent = UOp::index_const(256);
    let global_rng = UOp::range_axis(extent, AxisId::Renumbered(0), AxisType::Global);
    let body = UOp::native_const(1.0f32);
    let root = UOp::sink(vec![body, global_rng]);
    let renderer = Renderer::cpu();
    let mut sched = Scheduler::new(root, renderer);
    let oversized = Opt::upcast(0, 32);

    let Err(err) = apply_opt(&mut sched, &oversized, false) else {
        panic!("upcast by 32 should have been rejected");
    };

    assert!(matches!(err, OptError::DeviceLimitExceeded { limit_type: "upcast", .. }));
}
/// On the CUDA renderer, `Opt::local(0, 8)` on a Global range of 64 succeeds
/// and leaves a Global range followed by a Local range.
#[test]
fn test_local_basic() {
    let extent = UOp::index_const(64);
    let global_rng = UOp::range_axis(extent, AxisId::Renumbered(0), AxisType::Global);
    let body = UOp::native_const(1.0f32);
    let root = UOp::sink(vec![body, global_rng]);
    let renderer = Renderer::cuda();
    let mut sched = Scheduler::new(root, renderer);
    let local = Opt::local(0, 8);

    let outcome = apply_opt(&mut sched, &local, true);
    assert!(outcome.is_ok());

    assert_eq!(sched.applied_opts.len(), 1);
    assert_eq!(sched.shape_len(), 2);

    let ranges = sched.rngs();
    assert_eq!(get_axis_type(&ranges[0]), AxisType::Global);
    assert_eq!(get_axis_type(&ranges[1]), AxisType::Local);
}
/// On the CPU renderer, `Opt::local` must fail with
/// `UnsupportedFeature { feature: "local memory" }`.
#[test]
fn test_local_no_backend_support() {
    let extent = UOp::index_const(64);
    let global_rng = UOp::range_axis(extent, AxisId::Renumbered(0), AxisType::Global);
    let body = UOp::native_const(1.0f32);
    let root = UOp::sink(vec![body, global_rng]);
    let renderer = Renderer::cpu();
    let mut sched = Scheduler::new(root, renderer);
    let local = Opt::local(0, 8);

    let outcome = apply_opt(&mut sched, &local, false);

    assert!(matches!(outcome, Err(OptError::UnsupportedFeature { feature: "local memory" })));
}
/// `Opt::local` on a kernel whose only range is a Reduce range must fail with
/// `ValidationFailed` for the "LOCAL" op, even on the CUDA renderer.
#[test]
fn test_local_invalid_axis_type() {
    let extent = UOp::index_const(32);
    let reduce_rng = UOp::range_axis(extent, AxisId::Renumbered(0), AxisType::Reduce);
    let scalar = UOp::native_const(1.0f32);
    let reduction = scalar.reduce(vec![reduce_rng].into(), ReduceOp::Add);
    let root = UOp::sink(vec![reduction]);
    let renderer = Renderer::cuda();
    let mut sched = Scheduler::new(root, renderer);
    let local = Opt::local(0, 4);

    let Err(err) = apply_opt(&mut sched, &local, false) else {
        panic!("LOCAL on a reduce axis should have been rejected");
    };

    assert!(matches!(err, OptError::ValidationFailed { op: "LOCAL", .. }));
}
/// `Opt::unroll(0, 4)` on a single Reduce range of 32 succeeds and leaves a
/// Reduce range followed by an Unroll range.
#[test]
fn test_unroll_basic() {
    let extent = UOp::index_const(32);
    let reduce_rng = UOp::range_axis(extent, AxisId::Renumbered(0), AxisType::Reduce);
    let scalar = UOp::native_const(1.0f32);
    let reduction = scalar.reduce(vec![reduce_rng].into(), ReduceOp::Add);
    let root = UOp::sink(vec![reduction]);
    let renderer = Renderer::cpu();
    let mut sched = Scheduler::new(root, renderer);

    // Precondition: exactly one dimension is reported as unrollable.
    let candidates = sched.unrollable_dims();
    assert_eq!(candidates.len(), 1);

    let unroll = Opt::unroll(0, 4);
    let outcome = apply_opt(&mut sched, &unroll, true);
    assert!(outcome.is_ok());

    assert_eq!(sched.applied_opts.len(), 1);
    assert_eq!(sched.shape_len(), 2);

    let ranges = sched.rngs();
    assert_eq!(get_axis_type(&ranges[0]), AxisType::Reduce);
    assert_eq!(get_axis_type(&ranges[1]), AxisType::Unroll);
}
/// Requesting unroll axis 1 when only one Reduce range exists must fail with
/// `AxisOutOfBounds`.
#[test]
fn test_unroll_axis_out_of_bounds() {
    let extent = UOp::index_const(32);
    let reduce_rng = UOp::range_axis(extent, AxisId::Renumbered(0), AxisType::Reduce);
    let scalar = UOp::native_const(1.0f32);
    let reduction = scalar.reduce(vec![reduce_rng].into(), ReduceOp::Add);
    let root = UOp::sink(vec![reduction]);
    let renderer = Renderer::cpu();
    let mut sched = Scheduler::new(root, renderer);
    let unroll = Opt::unroll(1, 4);

    let outcome = apply_opt(&mut sched, &unroll, false);

    assert!(matches!(outcome, Err(OptError::AxisOutOfBounds { .. })));
}
/// An unroll by 64 must be rejected with `DeviceLimitExceeded` for the
/// "unroll" limit.
#[test]
fn test_unroll_excessive_amount() {
    let extent = UOp::index_const(128);
    let reduce_rng = UOp::range_axis(extent, AxisId::Renumbered(0), AxisType::Reduce);
    let scalar = UOp::native_const(1.0f32);
    let reduction = scalar.reduce(vec![reduce_rng].into(), ReduceOp::Add);
    let root = UOp::sink(vec![reduction]);
    let renderer = Renderer::cpu();
    let mut sched = Scheduler::new(root, renderer);
    let oversized = Opt::unroll(0, 64);

    let Err(err) = apply_opt(&mut sched, &oversized, false) else {
        panic!("unroll by 64 should have been rejected");
    };

    assert!(matches!(err, OptError::DeviceLimitExceeded { limit_type: "unroll", .. }));
}
/// An upcast followed by an unroll are both accepted, recorded in application
/// order, and together bring the shape length to 4.
#[test]
fn test_apply_opt_multiple_operations() {
    let extent_64 = UOp::index_const(64);
    let global_rng = UOp::range_axis(extent_64.clone(), AxisId::Renumbered(0), AxisType::Global);
    let extent_32 = UOp::index_const(32);
    let reduce_rng = UOp::range_axis(extent_32, AxisId::Renumbered(1), AxisType::Reduce);
    let scalar = UOp::native_const(1.0f32);
    let reduction = scalar.reduce(vec![global_rng, reduce_rng].into(), ReduceOp::Add);
    let root = UOp::sink(vec![reduction]);
    let renderer = Renderer::cpu();
    let mut sched = Scheduler::new(root, renderer);

    let upcast = Opt::upcast(0, 4);
    assert!(apply_opt(&mut sched, &upcast, true).is_ok());

    let unroll = Opt::unroll(0, 8);
    assert!(apply_opt(&mut sched, &unroll, true).is_ok());

    assert_eq!(sched.applied_opts.len(), 2);
    assert_eq!(sched.applied_opts[0], upcast);
    assert_eq!(sched.applied_opts[1], unroll);
    assert_eq!(sched.shape_len(), 4);
}
/// An UPCAST opt carrying a `TensorCore` argument must be rejected with
/// `InvalidArgType` whose `expected` field is "Int".
#[test]
fn test_apply_opt_invalid_arg_type() {
    let extent = UOp::index_const(16);
    let global_rng = UOp::range_axis(extent, AxisId::Renumbered(0), AxisType::Global);
    let body = UOp::native_const(1.0f32);
    let root = UOp::sink(vec![body, global_rng]);
    let renderer = Renderer::cpu();
    let mut sched = Scheduler::new(root, renderer);
    // Deliberately mismatched: UPCAST paired with a tensor-core argument.
    let wrong_arg = OptArg::TensorCore { tc_select: 0, opt_level: 0, use_tc: 0 };
    let malformed = Opt::new(OptOps::UPCAST, Some(0), wrong_arg);

    let outcome = apply_opt(&mut sched, &malformed, false);

    assert!(matches!(outcome, Err(OptError::InvalidArgType { expected: "Int", .. })));
}
/// `Opt::nolocals()` sets `dont_use_locals`; a later `Opt::local` must then
/// fail with `ValidationFailed` for the "LOCAL" op.
#[test]
fn test_nolocals_basic() {
    let extent = UOp::index_const(16);
    let global_rng = UOp::range_axis(extent, AxisId::Renumbered(0), AxisType::Global);
    let body = UOp::native_const(1.0f32);
    let root = UOp::sink(vec![body, global_rng]);
    let renderer = Renderer::cuda();
    let mut sched = Scheduler::new(root, renderer);

    let nolocals = Opt::nolocals();
    let first = apply_opt(&mut sched, &nolocals, true);
    assert!(first.is_ok());
    assert!(sched.dont_use_locals);

    let local = Opt::local(0, 4);
    let second = apply_opt(&mut sched, &local, false);
    assert!(matches!(second, Err(OptError::ValidationFailed { op: "LOCAL", .. })));
}
/// Once a LOCAL opt has been applied, `Opt::nolocals()` must fail with
/// `ValidationFailed` for the "NOLOCALS" op.
#[test]
fn test_nolocals_with_existing_local() {
    let extent = UOp::index_const(64);
    let global_rng = UOp::range_axis(extent.clone(), AxisId::Renumbered(0), AxisType::Global);
    let body = UOp::native_const(1.0f32);
    let root = UOp::sink(vec![body, global_rng]);
    let renderer = Renderer::cuda();
    let mut sched = Scheduler::new(root, renderer);

    let local = Opt::local(0, 8);
    assert!(apply_opt(&mut sched, &local, true).is_ok());

    let nolocals = Opt::nolocals();
    let Err(err) = apply_opt(&mut sched, &nolocals, false) else {
        panic!("NOLOCALS after LOCAL should have been rejected");
    };
    assert!(matches!(err, OptError::ValidationFailed { op: "NOLOCALS", .. }));
}
/// After `Opt::swap(0, 1)` the constant sizes seen at positions 0 and 1 of
/// `rngs()` must have traded places.
#[test]
fn test_swap_basic() {
    let extent_16 = UOp::index_const(16);
    let extent_32 = UOp::index_const(32);
    let global_a = UOp::range_axis(extent_16, AxisId::Renumbered(0), AxisType::Global);
    let global_b = UOp::range_axis(extent_32, AxisId::Renumbered(1), AxisType::Global);
    let body = UOp::native_const(1.0f32);
    let root = UOp::sink(vec![body, global_a, global_b]);
    let renderer = Renderer::cpu();
    let mut sched = Scheduler::new(root, renderer);

    // Extracts the constant integer extent of a Range node.
    let extent_of = |rng: &Arc<UOp>| -> i64 {
        let Op::Range { end, .. } = rng.op() else {
            panic!("Expected Range with constant size");
        };
        let Op::Const(cv) = end.op() else {
            panic!("Expected Range with constant size");
        };
        let ConstValue::Int(sz) = cv.0 else {
            panic!("Expected Range with constant size");
        };
        sz
    };

    let before = sched.rngs();
    let size0_before = extent_of(&before[0]);
    let size1_before = extent_of(&before[1]);

    let swap = Opt::swap(0, 1);
    let outcome = apply_opt(&mut sched, &swap, true);
    assert!(outcome.is_ok());

    let after = sched.rngs();
    let size0_after = extent_of(&after[0]);
    let size1_after = extent_of(&after[1]);
    assert_eq!(size0_after, size1_before, "axis_id 0 should now have the size that axis_id 1 had");
    assert_eq!(size1_after, size0_before, "axis_id 1 should now have the size that axis_id 0 had");
}
/// Swapping with axis 5 when only two ranges exist must fail with
/// `AxisOutOfBounds`.
#[test]
fn test_swap_invalid_axis() {
    let extent_16 = UOp::index_const(16);
    let extent_32 = UOp::index_const(32);
    let global_a = UOp::range_axis(extent_16, AxisId::Renumbered(0), AxisType::Global);
    let global_b = UOp::range_axis(extent_32, AxisId::Renumbered(1), AxisType::Global);
    let body = UOp::native_const(1.0f32);
    let root = UOp::sink(vec![body, global_a, global_b]);
    let renderer = Renderer::cpu();
    let mut sched = Scheduler::new(root, renderer);
    let swap = Opt::swap(0, 5);

    let outcome = apply_opt(&mut sched, &swap, false);

    assert!(matches!(outcome, Err(OptError::AxisOutOfBounds { .. })));
}
/// Swapping a Global range with a Reduce range must fail with
/// `ValidationFailed` for the "SWAP" op.
#[test]
fn test_swap_non_global_axis() {
    let extent_16 = UOp::index_const(16);
    let extent_32 = UOp::index_const(32);
    let global_rng = UOp::range_axis(extent_16, AxisId::Renumbered(0), AxisType::Global);
    let reduce_rng = UOp::range_axis(extent_32, AxisId::Renumbered(1), AxisType::Reduce);
    let scalar = UOp::native_const(1.0f32);
    let reduction = scalar.reduce(vec![global_rng.clone(), reduce_rng].into(), ReduceOp::Add);
    let root = UOp::sink(vec![reduction]);
    let renderer = Renderer::cpu();
    let mut sched = Scheduler::new(root, renderer);
    let swap = Opt::swap(0, 1);

    let Err(err) = apply_opt(&mut sched, &swap, false) else {
        panic!("SWAP involving a reduce axis should have been rejected");
    };

    assert!(matches!(err, OptError::ValidationFailed { op: "SWAP", .. }));
}
/// On the CUDA renderer, `Opt::group(0, 8)` on a Reduce range of 64 succeeds
/// and leaves a GroupReduce range followed by a Reduce range.
#[test]
fn test_group_basic() {
    let extent = UOp::index_const(64);
    let reduce_rng = UOp::range_axis(extent, AxisId::Renumbered(0), AxisType::Reduce);
    let scalar = UOp::native_const(1.0f32);
    let reduction = scalar.reduce(vec![reduce_rng].into(), ReduceOp::Add);
    let root = UOp::sink(vec![reduction]);
    let renderer = Renderer::cuda();
    let mut sched = Scheduler::new(root, renderer);
    let group = Opt::group(0, 8);

    let outcome = apply_opt(&mut sched, &group, true);
    assert!(outcome.is_ok());

    assert_eq!(sched.shape_len(), 2);
    let ranges = sched.rngs();
    assert_eq!(get_axis_type(&ranges[0]), AxisType::GroupReduce);
    assert_eq!(get_axis_type(&ranges[1]), AxisType::Reduce);
}
/// On the CPU renderer, `Opt::group` must fail with `UnsupportedFeature`.
#[test]
fn test_group_no_shared_memory() {
    let extent = UOp::index_const(64);
    let reduce_rng = UOp::range_axis(extent, AxisId::Renumbered(0), AxisType::Reduce);
    let scalar = UOp::native_const(1.0f32);
    let reduction = scalar.reduce(vec![reduce_rng].into(), ReduceOp::Add);
    let root = UOp::sink(vec![reduction]);
    let renderer = Renderer::cpu();
    let mut sched = Scheduler::new(root, renderer);
    let group = Opt::group(0, 8);

    let outcome = apply_opt(&mut sched, &group, false);

    assert!(matches!(outcome, Err(OptError::UnsupportedFeature { .. })));
}
/// With the CUDA renderer, `convert_loop_to_global` must turn both Loop
/// ranges into Global ranges.
#[test]
fn test_convert_loop_to_global_gpu() {
    let outer = UOp::range_axis(UOp::index_const(16), AxisId::Renumbered(0), AxisType::Loop);
    let inner = UOp::range_axis(UOp::index_const(16), AxisId::Renumbered(1), AxisType::Loop);
    let body = UOp::native_const(1.0f32);
    let root = UOp::sink(vec![body, outer.clone(), inner.clone()]);
    let renderer = Renderer::cuda();
    let mut sched = Scheduler::new(root, renderer);

    sched.convert_loop_to_global().unwrap();

    let converted = sched.rngs();
    assert_eq!(converted.len(), 2);
    for rng in converted {
        match rng.op() {
            Op::Range { axis_type, .. } => assert_eq!(*axis_type, AxisType::Global),
            _ => panic!("Expected RANGE operation"),
        }
    }
}
/// With the CPU renderer, `convert_loop_to_global` must leave the single
/// Loop range as a Loop range.
#[test]
fn test_convert_loop_to_global_cpu() {
    let loop1 = UOp::range_axis(UOp::index_const(16), AxisId::Renumbered(0), AxisType::Loop);
    let val = UOp::native_const(1.0f32);
    let sink = UOp::sink(vec![val, loop1.clone()]);
    let ren = Renderer::cpu();
    let mut scheduler = Scheduler::new(sink, ren);

    scheduler.convert_loop_to_global().unwrap();

    let ranges = scheduler.rngs();
    assert_eq!(ranges.len(), 1);
    // Panic on a non-Range node: with a plain `if let`, the axis-type check
    // would be skipped and the test would pass without verifying anything.
    let Op::Range { axis_type, .. } = ranges[0].op() else {
        panic!("Expected RANGE operation");
    };
    assert_eq!(*axis_type, AxisType::Loop);
}
/// The optimized AST of a reduce kernel carries `KernelInfo` metadata whose
/// generated name starts with "r_" and mentions every axis of the kernel.
#[test]
fn test_get_optimized_ast_reduce_kernel() {
    use crate::optimizer::KernelInfo;

    let global_rng = UOp::range_axis(UOp::index_const(16), AxisId::Renumbered(0), AxisType::Global);
    let local_rng = UOp::range_axis(UOp::index_const(8), AxisId::Renumbered(1), AxisType::Local);
    let reduce_rng = UOp::range_axis(UOp::index_const(32), AxisId::Renumbered(2), AxisType::Reduce);
    let upcast_rng = UOp::range_axis(UOp::index_const(4), AxisId::Renumbered(3), AxisType::Upcast);
    let scalar = UOp::native_const(1.0f32);
    let reduction = scalar.reduce(vec![reduce_rng.clone()].into(), ReduceOp::Add);
    let root = UOp::sink(vec![reduction, global_rng, local_rng, upcast_rng]);
    let renderer = Renderer::cuda();
    let sched = Scheduler::new(root, renderer);

    let optimized = sched.get_optimized_ast(None);

    let meta = optimized.metadata::<KernelInfo>();
    assert!(meta.is_some());
    let meta = meta.unwrap();
    assert!(meta.name.starts_with("r_"));
    for fragment in ["g16", "l8", "R32", "u4"] {
        assert!(meta.name.contains(fragment));
    }
}
/// The optimized AST of a reduction-free kernel carries `KernelInfo` metadata
/// whose generated name starts with "E_" and contains "g256".
#[test]
fn test_get_optimized_ast_elementwise_kernel() {
    use crate::optimizer::KernelInfo;

    let global_rng = UOp::range_axis(UOp::index_const(256), AxisId::Renumbered(0), AxisType::Global);
    let body = UOp::native_const(1.0f32);
    let root = UOp::sink(vec![body, global_rng]);
    let renderer = Renderer::cuda();
    let sched = Scheduler::new(root, renderer);

    let optimized = sched.get_optimized_ast(None);

    let meta = optimized.metadata::<KernelInfo>();
    assert!(meta.is_some());
    let meta = meta.unwrap();
    assert!(meta.name.starts_with("E_"));
    assert!(meta.name.contains("g256"));
}
/// A name passed explicitly to `get_optimized_ast` must be stored verbatim in the kernel
/// metadata instead of an auto-generated one.
#[test]
fn test_get_optimized_ast_custom_name() {
    use crate::optimizer::KernelInfo;

    let global_axis =
        UOp::range_axis(UOp::index_const(16), AxisId::Renumbered(0), AxisType::Global);
    let value = UOp::native_const(1.0f32);
    let sink = UOp::sink(vec![value, global_axis]);

    let renderer = Renderer::cuda();
    let requested_name = Some(String::from("custom_kernel"));
    let optimized = Scheduler::new(sink, renderer).get_optimized_ast(requested_name);

    // The optimized AST must carry kernel metadata.
    let metadata = optimized.metadata::<KernelInfo>();
    assert!(metadata.is_some());
    let kernel_info = metadata.unwrap();

    assert_eq!(kernel_info.name, "custom_kernel");
}
/// End-to-end pass on the CUDA renderer: convert LOOP axes to GLOBAL, apply one UPCAST opt,
/// then check that the optimized AST's metadata records exactly that one applied opt.
#[test]
fn test_phase7_integration() {
    use crate::optimizer::KernelInfo;

    let first_loop = UOp::range_axis(UOp::index_const(16), AxisId::Renumbered(0), AxisType::Loop);
    let second_loop =
        UOp::range_axis(UOp::index_const(16), AxisId::Renumbered(1), AxisType::Loop);
    let value = UOp::native_const(1.0f32);
    let sink = UOp::sink(vec![value, first_loop, second_loop]);

    let renderer = Renderer::cuda();
    let mut scheduler = Scheduler::new(sink, renderer);

    // Step 1: LOOP -> GLOBAL conversion.
    scheduler.convert_loop_to_global().unwrap();

    // Step 2: upcast axis 0 by 4 (the trailing `true` is forwarded to `apply_opt` as-is).
    let upcast = Opt::upcast(0, 4);
    apply_opt(&mut scheduler, &upcast, true).unwrap();

    // Step 3: finalize and inspect the recorded optimizations.
    let optimized = scheduler.get_optimized_ast(None);
    let metadata = optimized.metadata::<KernelInfo>();
    assert!(metadata.is_some());
    let kernel_info = metadata.unwrap();

    assert_eq!(kernel_info.applied_opts.len(), 1);
    assert_eq!(kernel_info.applied_opts[0].op, OptOps::UPCAST);
}
/// Optimizing the same AST three times in a row must yield three distinct kernel names that
/// all share the `E_g16` prefix; the second and third names must additionally contain an `n`
/// (the deduplication suffix marker).
#[test]
fn test_kernel_name_deduplication() {
    use crate::optimizer::{KernelInfo, clear_kernel_name_counts};

    // NOTE(review): presumably resets the counters that drive the dedup suffix, so the test
    // starts from a known state — confirm against `clear_kernel_name_counts`.
    clear_kernel_name_counts();

    let global_axis =
        UOp::range_axis(UOp::index_const(16), AxisId::Renumbered(0), AxisType::Global);
    let value = UOp::native_const(1.0f32);
    let sink = UOp::sink(vec![value, global_axis]);
    let renderer = Renderer::cuda();

    // Three schedulers over the very same AST and renderer.
    let first_sched = Scheduler::new(sink.clone(), renderer.clone());
    let second_sched = Scheduler::new(sink.clone(), renderer.clone());
    let third_sched = Scheduler::new(sink.clone(), renderer);

    // Optimize in order; the naming is expected to depend on this order.
    let first_ast = first_sched.get_optimized_ast(None);
    let second_ast = second_sched.get_optimized_ast(None);
    let third_ast = third_sched.get_optimized_ast(None);

    let first = first_ast.metadata::<KernelInfo>().unwrap();
    let second = second_ast.metadata::<KernelInfo>().unwrap();
    let third = third_ast.metadata::<KernelInfo>().unwrap();

    // All three names are pairwise different.
    assert_ne!(first.name, second.name, "Second kernel should have different name than first");
    assert_ne!(second.name, third.name, "Third kernel should have different name than second");
    assert_ne!(first.name, third.name, "Third kernel should have different name than first");

    // Every name keeps the common base.
    for (name, ordinal) in [(&first.name, "First"), (&second.name, "Second"), (&third.name, "Third")]
    {
        assert!(name.starts_with("E_g16"), "{} kernel name should start with E_g16", ordinal);
    }

    // Only the repeated kernels carry the suffix marker.
    for (name, ordinal) in [(&second.name, "Second"), (&third.name, "Third")] {
        assert!(name.contains('n'), "{} kernel should have deduplication suffix", ordinal);
    }

    clear_kernel_name_counts();
}
/// With the CUDA renderer, `convert_loop_to_global` must turn every LOOP range that hangs off
/// the SINK into a GLOBAL range.
#[test]
fn test_globalizable_rngs_with_sink() {
    let loop1 = UOp::range_axis(UOp::index_const(16), AxisId::Renumbered(0), AxisType::Loop);
    let loop2 = UOp::range_axis(UOp::index_const(16), AxisId::Renumbered(1), AxisType::Loop);
    let val = UOp::native_const(1.0f32);
    let sink = UOp::sink(vec![val, loop1, loop2]);
    let ren = Renderer::cuda();
    let mut scheduler = Scheduler::new(sink, ren);

    scheduler.convert_loop_to_global().unwrap();

    let ranges = scheduler.rngs();
    assert_eq!(ranges.len(), 2);
    for rng in ranges {
        // Fail loudly when an entry is not a RANGE; without the `else` branch the axis-type
        // assertion would be skipped and the test would pass vacuously.
        if let Op::Range { axis_type, .. } = rng.op() {
            assert_eq!(*axis_type, AxisType::Global, "LOOP axes in SINK should be converted to GLOBAL");
        } else {
            panic!("Expected RANGE operation");
        }
    }
}
/// A STORE whose stored value is a REDUCE must still go through `get_optimized_ast` and come
/// out with kernel metadata attached.
#[test]
fn test_flatten_ranges_store() {
    use crate::optimizer::KernelInfo;

    // Value reduced (Add) over a 32-wide REDUCE range.
    let reduce_axis =
        UOp::range_axis(UOp::index_const(32), AxisId::Renumbered(0), AxisType::Reduce);
    let value = UOp::native_const(1.0f32);
    let reduced = value.clone().reduce(vec![reduce_axis].into(), ReduceOp::Add);

    // The reduction is stored through a constant index, making STORE the AST root.
    let target = UOp::index_const(0);
    let store = target.store(reduced);

    let renderer = Renderer::cuda();
    let optimized = Scheduler::new(store, renderer).get_optimized_ast(None);

    let metadata = optimized.metadata::<KernelInfo>();
    assert!(metadata.is_some(), "STORE with nested REDUCE should be flattened successfully");
}
/// Applying a THREAD opt to a single 64-wide LOOP axis on the CPU renderer must succeed and
/// leave the schedule with two axes: one Thread axis and one Loop axis.
#[test]
fn test_thread_basic() {
    let extent = UOp::index_const(64);
    let loop_axis = UOp::range_axis(extent, AxisId::Renumbered(0), AxisType::Loop);
    let value = UOp::native_const(1.0f32);
    let sink = UOp::sink(vec![value, loop_axis]);

    let renderer = Renderer::cpu();
    let mut scheduler = Scheduler::new(sink, renderer);

    // NOTE(review): the amount is machine-dependent (falls back to 4 when the parallelism
    // cannot be queried); confirm THREAD accepts every value this can yield for extent 64.
    let thread_count = std::thread::available_parallelism().map(|p| p.get()).unwrap_or(4);
    let thread_opt = Opt::thread(0, thread_count);

    let outcome = apply_opt(&mut scheduler, &thread_opt, true);
    assert!(outcome.is_ok(), "THREAD opt should succeed: {:?}", outcome);

    // One axis in, two axes out.
    assert_eq!(scheduler.shape_len(), 2);

    let axis_types: Vec<AxisType> = scheduler.rngs().iter().map(|r| get_axis_type(r)).collect();
    assert!(axis_types.contains(&AxisType::Thread), "Should have Thread axis: {:?}", axis_types);
    assert!(axis_types.contains(&AxisType::Loop), "Should have Loop axis: {:?}", axis_types);
}
/// The threading heuristic must report success for a single LOOP axis of extent 262144 on the
/// CPU renderer and leave at least one Thread axis in the schedule.
#[test]
fn test_apply_threading_heuristic_loop() {
    use crate::optimizer::heuristics::apply_threading;

    let extent = UOp::index_const(262144);
    let loop_axis = UOp::range_axis(extent, AxisId::Renumbered(0), AxisType::Loop);
    let value = UOp::native_const(1.0f32);
    let sink = UOp::sink(vec![value, loop_axis]);

    let renderer = Renderer::cpu();
    let mut scheduler = Scheduler::new(sink, renderer);

    // NOTE(review): the second argument (2) is presumably the thread count — confirm against
    // `apply_threading`.
    let was_applied = apply_threading(&mut scheduler, 2);
    assert!(was_applied, "apply_threading should succeed on Loop axis with sufficient work");

    let thread_axes = scheduler.axes_of(&[AxisType::Thread]);
    assert!(!thread_axes.is_empty(), "Should have Thread axis after apply_threading");
}
/// The threading heuristic must decline when the only axis is an OUTER axis (extent 524288,
/// CPU renderer): it reports `false` and no Thread axis appears in the schedule.
#[test]
fn test_apply_threading_heuristic_outer_not_threaded() {
    use crate::optimizer::heuristics::apply_threading;

    let extent = UOp::index_const(524288);
    let outer_axis = UOp::range_axis(extent, AxisId::Renumbered(0), AxisType::Outer);
    let value = UOp::native_const(1.0f32);
    let sink = UOp::sink(vec![value, outer_axis]);

    let renderer = Renderer::cpu();
    let mut scheduler = Scheduler::new(sink, renderer);

    let was_applied = apply_threading(&mut scheduler, 2);
    assert!(
        !was_applied,
        "apply_threading should NOT succeed on Outer axis (only Loop axes are threaded)"
    );

    let thread_axes = scheduler.axes_of(&[AxisType::Thread]);
    assert!(thread_axes.is_empty(), "Should NOT have Thread axis for Outer");
}