use crate::optimizer::{Renderer, Scheduler, tc::*};
use morok_ir::{AxisId, AxisType, ReduceOp, UOp};
/// `detect_matmul` must not error on a sum-reduced product of two constants
/// that is sunk together with two global ranges.
///
/// Only the absence of an error is checked; the detected pattern itself is not inspected.
#[test]
fn test_detect_matmul_basic() {
    let row = UOp::range_axis(UOp::index_const(16), AxisId::Renumbered(0), AxisType::Global);
    let col = UOp::range_axis(UOp::index_const(16), AxisId::Renumbered(1), AxisType::Global);
    let depth = UOp::range_axis(UOp::index_const(16), AxisId::Renumbered(2), AxisType::Reduce);

    let lhs = UOp::native_const(1.0f32);
    let rhs = UOp::native_const(2.0f32);
    let product = lhs.try_mul(&rhs).unwrap();
    let summed = product.reduce(vec![depth].into(), ReduceOp::Add);

    let graph = UOp::sink(vec![summed, row, col]);
    let scheduler = Scheduler::new(graph, Renderer::cuda());

    let detection = matching::detect_matmul(&scheduler);
    assert!(detection.is_ok());
}
/// A sink holding a single constant (no reduce node) must yield `Ok(None)`.
#[test]
fn test_detect_matmul_no_reduce() {
    let lone_const = UOp::native_const(1.0f32);
    let graph = UOp::sink(vec![lone_const]);
    let scheduler = Scheduler::new(graph, Renderer::cuda());

    let detection = matching::detect_matmul(&scheduler);
    assert!(detection.is_ok());
    assert!(detection.unwrap().is_none());
}
/// A reduce whose source is a plain constant (not a multiplication) must yield `Ok(None)`.
#[test]
fn test_detect_matmul_not_mul() {
    let depth = UOp::range_axis(UOp::index_const(16), AxisId::Renumbered(0), AxisType::Reduce);
    let operand = UOp::native_const(1.0f32);
    let summed = operand.reduce(vec![depth].into(), ReduceOp::Add);

    let graph = UOp::sink(vec![summed]);
    let scheduler = Scheduler::new(graph, Renderer::cuda());

    let detection = matching::detect_matmul(&scheduler);
    assert!(detection.is_ok());
    assert!(detection.unwrap().is_none());
}
/// `select_tensor_core` with selector `-1` must not error on a hand-built 16x16x16 pattern.
///
/// The test name suggests `-1` means automatic selection; only `is_ok()` is checked here.
#[test]
fn test_select_tensor_core_auto() {
    let m_axis = UOp::range_axis(UOp::index_const(16), AxisId::Renumbered(0), AxisType::Global);
    let n_axis = UOp::range_axis(UOp::index_const(16), AxisId::Renumbered(1), AxisType::Global);
    let k_axis = UOp::range_axis(UOp::index_const(16), AxisId::Renumbered(2), AxisType::Reduce);

    let lhs = UOp::native_const(1.0f32);
    let rhs = UOp::native_const(2.0f32);
    let product = lhs.try_mul(&rhs).unwrap();
    let summed = product.reduce(vec![k_axis.clone()].into(), ReduceOp::Add);

    // `axis_choices` is listed last so it can take ownership of the axes after the clones above it.
    let pattern = matching::MatmulPattern {
        reduce_op: summed,
        in0: lhs,
        in1: rhs,
        in0_ranges: vec![m_axis.clone()],
        in1_ranges: vec![n_axis.clone()],
        red_ranges: vec![k_axis.clone()],
        axis_choices: vec![(n_axis, m_axis, k_axis)],
    };

    let cuda = Renderer::cuda();
    let picked = selection::select_tensor_core(&pattern, &cuda, -1, 0);
    assert!(picked.is_ok());
}
/// `select_tensor_core` with an explicit selector of `0` must not error on a
/// hand-built 16x16x16 pattern.
#[test]
fn test_select_tensor_core_specific() {
    let m_axis = UOp::range_axis(UOp::index_const(16), AxisId::Renumbered(0), AxisType::Global);
    let n_axis = UOp::range_axis(UOp::index_const(16), AxisId::Renumbered(1), AxisType::Global);
    let k_axis = UOp::range_axis(UOp::index_const(16), AxisId::Renumbered(2), AxisType::Reduce);

    let lhs = UOp::native_const(1.0f32);
    let rhs = UOp::native_const(2.0f32);
    let product = lhs.try_mul(&rhs).unwrap();
    let summed = product.reduce(vec![k_axis.clone()].into(), ReduceOp::Add);

    // `axis_choices` is listed last so it can take ownership of the axes after the clones above it.
    let pattern = matching::MatmulPattern {
        reduce_op: summed,
        in0: lhs,
        in1: rhs,
        in0_ranges: vec![m_axis.clone()],
        in1_ranges: vec![n_axis.clone()],
        red_ranges: vec![k_axis.clone()],
        axis_choices: vec![(n_axis, m_axis, k_axis)],
    };

    let cuda = Renderer::cuda();
    let picked = selection::select_tensor_core(&pattern, &cuda, 0, 0);
    assert!(picked.is_ok());
}
/// `select_tensor_core` must return an error when the selector (`9999`) is far
/// outside any plausible tensor-core index.
#[test]
fn test_select_tensor_core_out_of_bounds() {
    let m_axis = UOp::range_axis(UOp::index_const(16), AxisId::Renumbered(0), AxisType::Global);
    let n_axis = UOp::range_axis(UOp::index_const(16), AxisId::Renumbered(1), AxisType::Global);
    let k_axis = UOp::range_axis(UOp::index_const(16), AxisId::Renumbered(2), AxisType::Reduce);

    let lhs = UOp::native_const(1.0f32);
    let rhs = UOp::native_const(2.0f32);
    let product = lhs.try_mul(&rhs).unwrap();
    let summed = product.reduce(vec![k_axis.clone()].into(), ReduceOp::Add);

    // `axis_choices` is listed last so it can take ownership of the axes after the clones above it.
    let pattern = matching::MatmulPattern {
        reduce_op: summed,
        in0: lhs,
        in1: rhs,
        in0_ranges: vec![m_axis.clone()],
        in1_ranges: vec![n_axis.clone()],
        red_ranges: vec![k_axis.clone()],
        axis_choices: vec![(n_axis, m_axis, k_axis)],
    };

    let cuda = Renderer::cuda();
    let picked = selection::select_tensor_core(&pattern, &cuda, 9999, 0);
    assert!(picked.is_err());
}
/// `base_shape` for the `CUDA_81616` (f16 in, f32 out) tensor core must contain
/// index-0 upcast, local and reduce axes, with 2 upcast, 5 local and 4 reduce axes in total.
#[test]
fn test_base_shape() {
    use crate::optimizer::renderer::{CUDA_81616, SwizzleAxis};
    use morok_dtype::DType;

    let tensor_core = CUDA_81616.build(DType::Float16, DType::Float32);
    let axes = swizzle::base_shape(&tensor_core);

    assert!(!axes.is_empty());
    for required in [SwizzleAxis::Upcast(0), SwizzleAxis::Local(0), SwizzleAxis::Reduce(0)] {
        assert!(axes.contains(&required));
    }

    // Tally each axis kind separately.
    let upcasts = axes.iter().filter(|axis| matches!(**axis, SwizzleAxis::Upcast(_))).count();
    let locals = axes.iter().filter(|axis| matches!(**axis, SwizzleAxis::Local(_))).count();
    let reduces = axes.iter().filter(|axis| matches!(**axis, SwizzleAxis::Reduce(_))).count();

    assert_eq!(upcasts, 2);
    assert_eq!(locals, 5);
    assert_eq!(reduces, 4);
}
/// Both permutations returned by `permutes_for_shape` for `CUDA_81616` must be
/// non-empty and contain only indices that are valid positions in the base shape.
#[test]
fn test_permutes_for_shape() {
    use crate::optimizer::renderer::CUDA_81616;
    use morok_dtype::DType;

    let tensor_core = CUDA_81616.build(DType::Float16, DType::Float32);
    let axes = swizzle::base_shape(&tensor_core);
    let (first_perm, second_perm) = swizzle::permutes_for_shape(&tensor_core, &axes);

    assert!(!first_perm.is_empty());
    assert!(!second_perm.is_empty());

    let bound = axes.len();
    assert!(first_perm.iter().all(|&position| position < bound));
    assert!(second_perm.iter().all(|&position| position < bound));
}
/// `get_reduce_axes_count` must report 4 reduce axes for `CUDA_81616` (f16 in, f32 out).
#[test]
fn test_reduce_axes_count() {
    use crate::optimizer::renderer::CUDA_81616;
    use morok_dtype::DType;

    let tensor_core = CUDA_81616.build(DType::Float16, DType::Float32);
    assert_eq!(swizzle::get_reduce_axes_count(&tensor_core), 4);
}
/// Smoke test: `apply` must return (either `Ok` or `Err`) without panicking on a
/// 16x16x16 sum-reduced product of two constants.
///
/// The outcome is deliberately not asserted: the inputs are constants rather than
/// real loads, so whether tensor cores can be applied is not pinned down here.
#[test]
fn test_apply_tc_basic() {
    let i = UOp::range_axis(UOp::index_const(16), AxisId::Renumbered(0), AxisType::Global);
    let j = UOp::range_axis(UOp::index_const(16), AxisId::Renumbered(1), AxisType::Global);
    let k = UOp::range_axis(UOp::index_const(16), AxisId::Renumbered(2), AxisType::Reduce);

    let a_val = UOp::native_const(1.0f32);
    let b_val = UOp::native_const(2.0f32);
    let mul = a_val.try_mul(&b_val).unwrap();
    let reduce = mul.reduce(vec![k].into(), ReduceOp::Add);

    let sink = UOp::sink(vec![reduce, i, j]);
    let ren = Renderer::cuda();
    let mut scheduler = Scheduler::new(sink, ren);

    // Explicitly discard the result; the previous `is_ok() || is_err()` expression
    // was always true and only looked like a check.
    let _ = apply(&mut scheduler, -1, 0, 1);
}
/// `apply` must fail on a graph that is just a single constant.
#[test]
fn test_apply_tc_validation() {
    let lone_const = UOp::native_const(1.0f32);
    let graph = UOp::sink(vec![lone_const]);
    let mut scheduler = Scheduler::new(graph, Renderer::cuda());

    let outcome = apply(&mut scheduler, -1, 0, 1);
    assert!(outcome.is_err());
}
/// `apply` must return an error when called with `3` as its last argument
/// (`use_tc`, going by the test name) on a squared-constant reduction.
#[test]
fn test_apply_tc_invalid_use_tc() {
    let outer = UOp::range_axis(UOp::index_const(16), AxisId::Renumbered(0), AxisType::Global);
    let depth = UOp::range_axis(UOp::index_const(16), AxisId::Renumbered(1), AxisType::Reduce);

    let operand = UOp::native_const(1.0f32);
    let squared = operand.try_mul(&operand).unwrap();
    let summed = squared.reduce(vec![depth].into(), ReduceOp::Add);

    let graph = UOp::sink(vec![summed, outer]);
    let mut scheduler = Scheduler::new(graph, Renderer::cuda());

    let outcome = apply(&mut scheduler, -1, 0, 3);
    assert!(outcome.is_err());
}
use morok_dtype::DType;
use std::sync::Arc;
/// Builds a sink for a synthetic `m` x `n` x `k` multiply/add-reduce graph.
///
/// The operands are not memory loads. Each range is cast to `Float32`, the left
/// operand is `m + k`, the right operand is `k + n`, and their product is
/// sum-reduced over the `k` range. The sink holds the reduce plus the `m` and `n`
/// ranges.
fn create_matmul_pattern_for_padding(m: i64, n: i64, k: i64) -> Arc<morok_ir::UOp> {
    let rows = UOp::range_axis(UOp::index_const(m), AxisId::Renumbered(0), AxisType::Global);
    let cols = UOp::range_axis(UOp::index_const(n), AxisId::Renumbered(1), AxisType::Global);
    let depth = UOp::range_axis(UOp::index_const(k), AxisId::Renumbered(2), AxisType::Reduce);

    // Float views of the loop indices, so each operand varies with its own axes.
    let rows_f32 = rows.clone().cast(DType::Float32);
    let depth_f32 = depth.clone().cast(DType::Float32);
    let cols_f32 = cols.clone().cast(DType::Float32);

    let lhs = rows_f32.try_add(&depth_f32).unwrap();
    let rhs = depth_f32.try_add(&cols_f32).unwrap();
    let product = lhs.try_mul(&rhs).unwrap();
    let summed = product.reduce(vec![depth].into(), ReduceOp::Add);

    UOp::sink(vec![summed, rows, cols])
}
/// With 16x16x16 extents and `tc_opt = 1`, `apply` must not report a divisibility failure.
///
/// Any other error is tolerated; only a "not divisible" message fails the test.
#[test]
fn test_tc_no_padding_divisible_dims() {
    let graph = create_matmul_pattern_for_padding(16, 16, 16);
    let mut scheduler = Scheduler::new(graph, Renderer::cuda());

    let detected = matching::detect_matmul(&scheduler);
    assert!(detected.is_ok(), "Pattern detection should succeed");
    assert!(detected.unwrap().is_some(), "Matmul pattern should be found");

    let outcome = apply(&mut scheduler, -1, 1, 1);
    if let Err(failure) = &outcome {
        let rendered = format!("{:?}", failure);
        assert!(!rendered.contains("not divisible"), "16x16x16 should not fail divisibility check: {}", rendered);
    }
}
/// With a 15x16x16 problem and `tc_opt = 1`, `apply` must fail, and the error must
/// mention either "not divisible" or "no compatible".
#[test]
fn test_tc_rejects_non_divisible_without_tc_opt_2() {
    let graph = create_matmul_pattern_for_padding(15, 16, 16);
    let mut scheduler = Scheduler::new(graph, Renderer::cuda());

    let detected = matching::detect_matmul(&scheduler);
    assert!(detected.is_ok(), "Pattern detection should succeed");
    assert!(detected.unwrap().is_some(), "Matmul pattern should be found");

    let outcome = apply(&mut scheduler, -1, 1, 1);
    assert!(outcome.is_err(), "TC should fail for non-divisible dims with tc_opt=1");

    let rendered = format!("{:?}", outcome.unwrap_err());
    let is_expected_failure = ["not divisible", "no compatible"].iter().any(|&needle| rendered.contains(needle));
    assert!(is_expected_failure, "Should fail due to divisibility or no compatible TC: {}", rendered);
}
/// With a 15x16x16 problem and `tc_opt = 2`, `apply` must not reject with a
/// "not divisible" error (the assertion message says it should pad instead).
///
/// Any other error is tolerated by this test.
#[test]
fn test_tc_padding_with_tc_opt_2() {
    let graph = create_matmul_pattern_for_padding(15, 16, 16);
    let mut scheduler = Scheduler::new(graph, Renderer::cuda());

    let detected = matching::detect_matmul(&scheduler);
    assert!(detected.is_ok(), "Pattern detection should succeed");
    assert!(detected.unwrap().is_some(), "Matmul pattern should be found");

    let outcome = apply(&mut scheduler, -1, 2, 1);
    if let Err(failure) = &outcome {
        let rendered = format!("{:?}", failure);
        assert!(!rendered.contains("not divisible"), "tc_opt=2 should pad instead of rejecting: {}", rendered);
    }
}
/// With a 4x16x16 problem and `tc_opt = 2`, `apply` must fail.
///
/// The assertion message attributes this to a 4x work limit or to no compatible
/// tensor core; the test does not distinguish between the two.
#[test]
fn test_tc_padding_rejects_4x_work_increase() {
    let graph = create_matmul_pattern_for_padding(4, 16, 16);
    let mut scheduler = Scheduler::new(graph, Renderer::cuda());

    let detected = matching::detect_matmul(&scheduler);
    assert!(detected.is_ok(), "Pattern detection should succeed");
    assert!(detected.unwrap().is_some(), "Matmul pattern should be found");

    let outcome = apply(&mut scheduler, -1, 2, 1);
    assert!(outcome.is_err(), "Should fail due to 4x work limit or no compatible TC");
}
/// With a 17x17x17 problem (no axis a multiple of 16) and `tc_opt = 2`, `apply`
/// must not reject with a "not divisible" error.
///
/// Any other error is tolerated by this test.
#[test]
fn test_tc_padding_all_axes() {
    let graph = create_matmul_pattern_for_padding(17, 17, 17);
    let mut scheduler = Scheduler::new(graph, Renderer::cuda());

    let detected = matching::detect_matmul(&scheduler);
    assert!(detected.is_ok(), "Pattern detection should succeed");
    assert!(detected.unwrap().is_some(), "Matmul pattern should be found");

    let outcome = apply(&mut scheduler, -1, 2, 1);
    if let Err(failure) = &outcome {
        let rendered = format!("{:?}", failure);
        assert!(!rendered.contains("not divisible"), "tc_opt=2 should pad instead of rejecting: {}", rendered);
    }
}
/// `apply` must reject `tc_opt = 3` with an error whose debug output contains
/// "tc_opt must be".
#[test]
fn test_tc_opt_validation() {
    let graph = create_matmul_pattern_for_padding(16, 16, 16);
    let mut scheduler = Scheduler::new(graph, Renderer::cuda());

    let outcome = apply(&mut scheduler, -1, 3, 1);
    assert!(outcome.is_err(), "tc_opt=3 should be rejected");

    let rendered = format!("{:?}", outcome.unwrap_err());
    assert!(rendered.contains("tc_opt must be"), "Should fail validation: {}", rendered);
}
use crate::optimizer::renderer::{APPLE_AMX, SwizzleAxis};
use crate::optimizer::{Opt, apply_opt};
/// Pattern detection must succeed and find a pattern for a 16x16x16 graph under
/// the Apple AMX renderer.
#[test]
fn test_detect_matmul_amx() {
    let graph = create_matmul_pattern_for_padding(16, 16, 16);
    let scheduler = Scheduler::new(graph, Renderer::apple_amx());

    let detected = matching::detect_matmul(&scheduler);
    assert!(detected.is_ok(), "Pattern detection should succeed");
    assert!(detected.unwrap().is_some(), "Matmul pattern should be found for AMX");
}
/// Tensor-core selection under the Apple AMX renderer must not error, and when it
/// picks a core, that core must be 16x16x1, single-threaded, f32 in and f32 out.
///
/// NOTE(review): if selection returns `Ok(None)` the property checks are skipped
/// and the test still passes — confirm that is intended.
#[test]
fn test_select_tensor_core_amx() {
    let graph = create_matmul_pattern_for_padding(16, 16, 16);
    let scheduler = Scheduler::new(graph, Renderer::apple_amx());
    let pattern = matching::detect_matmul(&scheduler).unwrap().unwrap();

    let picked = selection::select_tensor_core(&pattern, &scheduler.ren, -1, 0);
    assert!(picked.is_ok(), "TC selection should not error: {:?}", picked.err());

    if let Ok(Some(choice)) = picked {
        let core = &scheduler.ren.tensor_cores[choice.tc_index];
        assert_eq!(core.dims, (16, 16, 1), "AMX float32 TC should have dims (16, 16, 1)");
        assert_eq!(core.threads, 1, "AMX uses 1 thread (CPU)");
        assert_eq!(core.dtype_in, DType::Float32);
        assert_eq!(core.dtype_out, DType::Float32);
    }
}
/// `base_shape` for the f32 Apple AMX tensor core must consist of 8 upcast axes
/// and no local or reduce axes.
#[test]
fn test_base_shape_amx() {
    let tensor_core = APPLE_AMX.build(DType::Float32, DType::Float32);
    let axes = swizzle::base_shape(&tensor_core);

    // Tally each axis kind separately.
    let upcasts = axes.iter().filter(|axis| matches!(**axis, SwizzleAxis::Upcast(_))).count();
    let locals = axes.iter().filter(|axis| matches!(**axis, SwizzleAxis::Local(_))).count();
    let reduces = axes.iter().filter(|axis| matches!(**axis, SwizzleAxis::Reduce(_))).count();

    assert_eq!(upcasts, 8, "AMX should have 8 upcast axes");
    assert_eq!(locals, 0, "AMX should have no local axes");
    assert_eq!(reduces, 0, "AMX K=1 → 0 reduce axes");
}
/// For the f32 Apple AMX tensor core, the first permutation must be the identity
/// and the second must swap the two halves of the base shape.
#[test]
fn test_permutes_amx() {
    let tensor_core = APPLE_AMX.build(DType::Float32, DType::Float32);
    let axes = swizzle::base_shape(&tensor_core);
    let (first_perm, second_perm) = swizzle::permutes_for_shape(&tensor_core, &axes);

    let axis_count = axes.len();
    let identity: Vec<usize> = (0..axis_count).collect();
    assert_eq!(first_perm, identity, "AMX A permutation should be identity");

    // Upper half of the indices first, then the lower half.
    let midpoint = axis_count / 2;
    let halves_swapped: Vec<usize> = (midpoint..axis_count).chain(0..midpoint).collect();
    assert_eq!(second_perm, halves_swapped, "AMX B permutation should swap halves");
}
/// `get_reduce_axes_count` must report 0 reduce axes for the f32 Apple AMX tensor core.
#[test]
fn test_reduce_axes_count_amx() {
    let tensor_core = APPLE_AMX.build(DType::Float32, DType::Float32);
    let reduce_axes = swizzle::get_reduce_axes_count(&tensor_core);
    assert_eq!(reduce_axes, 0, "AMX K=1 should produce 0 reduce axes");
}
/// `apply` under the Apple AMX renderer must succeed on a 16x16x16 graph, return
/// only `Range` ops as axes, and leave a `Wmma` op somewhere in the scheduler's AST.
#[test]
fn test_apply_tc_amx() {
    let graph = create_matmul_pattern_for_padding(16, 16, 16);
    let mut scheduler = Scheduler::new(graph, Renderer::apple_amx());

    let outcome = apply(&mut scheduler, -1, 0, 1);
    assert!(outcome.is_ok(), "AMX TC apply should succeed: {:?}", outcome.err());

    let returned_axes = outcome.unwrap();
    returned_axes.iter().enumerate().for_each(|(i, axis)| {
        assert!(matches!(axis.op(), morok_ir::Op::Range { .. }), "axes[{i}] should be a RANGE");
    });

    let rewritten = scheduler.ast();
    let wmma_present = rewritten.toposort().iter().any(|node| matches!(node.op(), morok_ir::Op::Wmma { .. }));
    assert!(wmma_present, "AST should contain WMMA after TC apply");
}
/// After a TC opt has been applied successfully, a following GROUP opt must be
/// rejected with an error mentioning "no grouping with tensor cores".
#[test]
fn test_group_after_tc_rejected() {
    let graph = create_matmul_pattern_for_padding(16, 16, 16);
    let mut scheduler = Scheduler::new(graph, Renderer::apple_amx());

    // Step 1: the tensor-core opt itself has to go through.
    let tensor_core_step = Opt::tc(None, -1, 0, 1);
    let tc_outcome = apply_opt(&mut scheduler, &tensor_core_step, true);
    assert!(tc_outcome.is_ok(), "TC apply should succeed: {:?}", tc_outcome.err());

    // Step 2: grouping on the same scheduler must now be refused.
    let group_step = Opt::group(0, 2);
    let group_outcome = apply_opt(&mut scheduler, &group_step, true);
    assert!(group_outcome.is_err(), "GROUP after TC should be rejected");

    let rendered = format!("{:?}", group_outcome.unwrap_err());
    assert!(rendered.contains("no grouping with tensor cores"), "Error should mention TC guard: {}", rendered);
}