use crate::optimizer::{Opt, Renderer, Scheduler, apply_opt};
use crate::test::helpers::*;
use morok_ir::{AxisType, ReduceOp};
/// Applies `Opt::upcast(0, factor)` for the factors 2, 4 and 8 to the pattern built by
/// `create_elementwise_pattern(&[16, 16])`, using the CPU renderer and a fresh scheduler per factor.
///
/// For every factor the schedule must end up with three axes: one `Upcast` and two `Global`.
#[test]
fn test_upcasts() {
    let kernel = create_elementwise_pattern(&[16, 16]);
    let cpu = Renderer::cpu();

    for factor in [2, 4, 8] {
        // Each factor starts from an untouched scheduler so the cases cannot influence each other.
        let mut sched = Scheduler::new(kernel.clone(), cpu.clone());

        let outcome = apply_opt(&mut sched, &Opt::upcast(0, factor), true);
        assert!(outcome.is_ok(), "UPCAST by {} should succeed: {:?}", factor, outcome.err());

        assert_eq!(sched.shape_len(), 3, "Should have 3 axes after upcast by {}", factor);
        assert_axis_count(&sched, AxisType::Upcast, 1);
        assert_axis_count(&sched, AxisType::Global, 2);

        // Only the factor-2 case additionally cross-checks the count through `axes_of`.
        if factor == 2 {
            let upcast_axes = sched.axes_of(&[AxisType::Upcast]);
            assert_eq!(upcast_axes.len(), 1, "Should have exactly one upcast axis");
        }
    }
}
/// Upcasts axis 0 of the `[4]` elementwise pattern by 4 (the whole axis) on the CPU renderer.
///
/// Afterwards the schedule must consist of a single axis of shape `[4]` typed `Upcast`, with no
/// `Global` axis left.
#[test]
fn test_full_upcast() {
    let mut sched = Scheduler::new(create_elementwise_pattern(&[4]), Renderer::cpu());

    let outcome = apply_opt(&mut sched, &Opt::upcast(0, 4), true);
    assert!(outcome.is_ok(), "Full UPCAST should succeed: {:?}", outcome.err());

    // Shape and axis-type checks on the resulting schedule.
    assert_eq!(
        sched.shape_len(),
        1,
        "Should have 1 axis after full upcast (Global(1) filtered)"
    );
    assert_shape_equal(&sched, &[4]);
    assert_axes_equal(&sched, &[AxisType::Upcast]);
    assert_axis_count(&sched, AxisType::Upcast, 1);
    assert_axis_count(&sched, AxisType::Global, 0);
}
/// Exercises LOCAL and GROUPTOP (alone and combined with UPCAST/UNROLL) on the pattern built by
/// `create_reduce_with_globals(&[4, 4, 128], 128, ReduceOp::Add)` with the CUDA renderer.
///
/// Every scenario runs on its own freshly constructed scheduler.
#[test]
fn test_local_and_grouped_reduce() {
    let kernel = create_reduce_with_globals(&[4, 4, 128], 128, ReduceOp::Add);
    let cuda = Renderer::cuda();
    let fresh = || Scheduler::new(kernel.clone(), cuda.clone());

    // A single LOCAL: expect one Local axis and three Global axes.
    for (axis, amount, label) in [
        (0, 2, "LOCAL by 2"),
        (2, 8, "LOCAL(2, 8)"),
        (2, 16, "LOCAL(2, 16)"),
    ] {
        let mut sched = fresh();
        let outcome = apply_opt(&mut sched, &Opt::local(axis, amount), true);
        assert!(outcome.is_ok(), "{} should succeed: {:?}", label, outcome.err());
        assert_axis_count(&sched, AxisType::Local, 1);
        assert_axis_count(&sched, AxisType::Global, 3);
    }

    // A single GROUPTOP by 2: expect one GroupReduce axis (error details included on failure).
    {
        let mut sched = fresh();
        let outcome = apply_opt(&mut sched, &Opt::grouptop(0, 2), true);
        assert!(outcome.is_ok(), "GROUPTOP by 2 should succeed: {:?}", outcome.err());
        assert_axis_count(&sched, AxisType::GroupReduce, 1);
    }

    // A single GROUPTOP with larger amounts: still one GroupReduce axis.
    for amount in [32, 64] {
        let mut sched = fresh();
        let outcome = apply_opt(&mut sched, &Opt::grouptop(0, amount), true);
        assert!(outcome.is_ok(), "GROUPTOP by {} should succeed", amount);
        assert_axis_count(&sched, AxisType::GroupReduce, 1);
    }

    // LOCAL followed by GROUPTOP, each step checked individually.
    {
        let mut sched = fresh();
        assert!(
            apply_opt(&mut sched, &Opt::local(0, 2), true).is_ok(),
            "LOCAL should succeed"
        );
        assert!(
            apply_opt(&mut sched, &Opt::grouptop(0, 2), true).is_ok(),
            "GROUPTOP after LOCAL should succeed"
        );
        assert_axis_count(&sched, AxisType::Local, 1);
        assert_axis_count(&sched, AxisType::GroupReduce, 1);
    }

    // LOCAL(2, 16) followed by GROUPTOP(0, 16); any failing step panics via `unwrap`.
    {
        let mut sched = fresh();
        for opt in [Opt::local(2, 16), Opt::grouptop(0, 16)] {
            apply_opt(&mut sched, &opt, true).unwrap();
        }
        assert_axis_count(&sched, AxisType::Local, 1);
        assert_axis_count(&sched, AxisType::GroupReduce, 1);
    }

    // LOCAL + GROUPTOP + UPCAST, then an UNROLL attempt when unrollable dims are reported.
    {
        let mut sched = fresh();
        for opt in [Opt::local(2, 2), Opt::grouptop(0, 2), Opt::upcast(0, 2)] {
            apply_opt(&mut sched, &opt, true).unwrap();
        }

        let has_unrollable = !sched.unrollable_dims().is_empty();
        if has_unrollable {
            // The result is intentionally discarded here; the Unroll assertion below is what
            // verifies that an unroll axis is actually present.
            let _ = apply_opt(&mut sched, &Opt::unroll(0, 2), true);
        }

        assert!(!sched.axes_of(&[AxisType::Local]).is_empty());
        assert!(!sched.axes_of(&[AxisType::Unroll]).is_empty());
        assert!(!sched.axes_of(&[AxisType::Upcast]).is_empty());
    }
}
/// Exercises GROUPTOP on both reduce axes of the pattern built by
/// `create_double_reduce_with_globals(&[8, 8], &[128, 128], ReduceOp::Add)` with the CUDA
/// renderer, alone and in combination with LOCAL, UPCAST and UNROLL.
///
/// Every scenario runs on its own freshly constructed scheduler.
#[test]
fn test_double_reduce() {
    let kernel = create_double_reduce_with_globals(&[8, 8], &[128, 128], ReduceOp::Add);
    let cuda = Renderer::cuda();
    let fresh = || Scheduler::new(kernel.clone(), cuda.clone());

    // A single GROUPTOP(0, 2): one GroupReduce axis (error details included on failure).
    {
        let mut sched = fresh();
        let outcome = apply_opt(&mut sched, &Opt::grouptop(0, 2), true);
        assert!(outcome.is_ok(), "GROUPTOP(0, 2) should succeed: {:?}", outcome.err());
        assert_axis_count(&sched, AxisType::GroupReduce, 1);
    }

    // A single GROUPTOP on either reduce axis: one GroupReduce axis.
    for (axis, amount) in [(0, 32), (1, 2), (1, 32)] {
        let mut sched = fresh();
        let outcome = apply_opt(&mut sched, &Opt::grouptop(axis, amount), true);
        assert!(outcome.is_ok(), "GROUPTOP({}, {}) should succeed", axis, amount);
        assert_axis_count(&sched, AxisType::GroupReduce, 1);
    }

    // GROUPTOP on axis 0 and then on axis 1: two GroupReduce axes.
    for (first, second) in [(2, 2), (16, 2), (4, 64)] {
        let mut sched = fresh();
        apply_opt(&mut sched, &Opt::grouptop(0, first), true).unwrap();
        apply_opt(&mut sched, &Opt::grouptop(1, second), true).unwrap();
        assert_axis_count(&sched, AxisType::GroupReduce, 2);
    }

    // Two GROUPTOPs followed by an UNROLL by 4 on the given axis.
    for (first, second, unroll_axis) in [(16, 2, 0), (2, 32, 2)] {
        let mut sched = fresh();
        apply_opt(&mut sched, &Opt::grouptop(0, first), true).unwrap();
        apply_opt(&mut sched, &Opt::grouptop(1, second), true).unwrap();
        apply_opt(&mut sched, &Opt::unroll(unroll_axis, 4), true).unwrap();
        assert!(!sched.axes_of(&[AxisType::GroupReduce, AxisType::Reduce]).is_empty());
        assert!(!sched.axes_of(&[AxisType::Unroll]).is_empty());
    }

    // Two LOCALs plus two GROUPTOPs.
    {
        let mut sched = fresh();
        for opt in [
            Opt::local(0, 4),
            Opt::local(1, 4),
            Opt::grouptop(0, 4),
            Opt::grouptop(1, 4),
        ] {
            apply_opt(&mut sched, &opt, true).unwrap();
        }
        assert_axis_count(&sched, AxisType::Local, 2);
        assert_axis_count(&sched, AxisType::GroupReduce, 2);
    }

    // Two LOCALs, two GROUPTOPs and one UNROLL.
    {
        let mut sched = fresh();
        for opt in [
            Opt::local(0, 4),
            Opt::local(1, 4),
            Opt::grouptop(0, 2),
            Opt::grouptop(1, 32),
            Opt::unroll(1, 4),
        ] {
            apply_opt(&mut sched, &opt, true).unwrap();
        }
        assert_axis_count(&sched, AxisType::Local, 2);
        assert!(!sched.axes_of(&[AxisType::GroupReduce, AxisType::Reduce]).is_empty());
        assert!(!sched.axes_of(&[AxisType::Unroll]).is_empty());
    }

    // Two LOCALs, two GROUPTOPs and one UPCAST.
    {
        let mut sched = fresh();
        for opt in [
            Opt::local(0, 2),
            Opt::local(1, 2),
            Opt::grouptop(0, 8),
            Opt::grouptop(1, 4),
            Opt::upcast(0, 2),
        ] {
            apply_opt(&mut sched, &opt, true).unwrap();
        }
        assert_axis_count(&sched, AxisType::Local, 2);
        assert_axis_count(&sched, AxisType::GroupReduce, 2);
        assert_axis_count(&sched, AxisType::Upcast, 1);
    }

    // Same as the previous scenario, followed by two UNROLLs.
    {
        let mut sched = fresh();
        for opt in [
            Opt::local(0, 2),
            Opt::local(1, 2),
            Opt::grouptop(0, 8),
            Opt::grouptop(1, 4),
            Opt::upcast(0, 2),
            Opt::unroll(0, 4),
            Opt::unroll(1, 4),
        ] {
            apply_opt(&mut sched, &opt, true).unwrap();
        }
        assert_axis_count(&sched, AxisType::Local, 2);
        assert!(!sched.axes_of(&[AxisType::GroupReduce, AxisType::Reduce]).is_empty());
        assert_axis_count(&sched, AxisType::Upcast, 1);
        assert!(!sched.axes_of(&[AxisType::Unroll]).is_empty());
    }

    // Two LOCALs, two GROUPTOPs and two UPCASTs: no Global axis is expected to remain.
    {
        let mut sched = fresh();
        for opt in [
            Opt::local(0, 4),
            Opt::local(1, 4),
            Opt::grouptop(0, 4),
            Opt::grouptop(1, 4),
            Opt::upcast(0, 2),
            Opt::upcast(0, 2),
        ] {
            apply_opt(&mut sched, &opt, true).unwrap();
        }
        assert_axis_count(&sched, AxisType::Global, 0);
        assert_axis_count(&sched, AxisType::Local, 2);
        assert_axis_count(&sched, AxisType::GroupReduce, 2);
        assert_axis_count(&sched, AxisType::Upcast, 2);
    }
}