use morok_ir::{AxisId, AxisType, Op, ReduceOp, UOp};
use std::sync::Arc;
use crate::optimizer::Scheduler;
pub fn create_simple_reduce(size: i64, reduce_op: ReduceOp) -> Arc<UOp> {
    use smallvec::smallvec;

    // Range bounds are kept within the i32 range, but an out-of-range size now fails
    // loudly instead of being truncated to an unrelated bound.
    let size = i32::try_from(size).unwrap_or_else(|_| panic!("reduce size {size} does not fit in i32"));
    let size_uop = UOp::native_const(i64::from(size));
    let range = UOp::range_axis(size_uop, AxisId::Renumbered(0), AxisType::Reduce);
    let const_val = UOp::native_const(1.0f32);
    let reduce = const_val.reduce(smallvec![range], reduce_op);
    UOp::sink(vec![reduce])
}
/// Builds a kernel with one global axis per entry of `global_sizes` plus a single reduce axis.
///
/// Global axes get ids `0..global_sizes.len()` and the reduce axis gets the next id. The
/// reduced value is the constant `1.0f32`. The sink's sources are the REDUCE first, then the
/// global RANGEs in input order.
///
/// # Panics
///
/// Panics if any size does not fit in an `i32`.
pub fn create_reduce_with_globals(global_sizes: &[i64], reduce_size: i64, reduce_op: ReduceOp) -> Arc<UOp> {
    use smallvec::smallvec;

    // Range bounds must fit in i32; fail loudly rather than silently truncating.
    let bound = |size: i64| {
        let size = i32::try_from(size).unwrap_or_else(|_| panic!("axis size {size} does not fit in i32"));
        UOp::native_const(i64::from(size))
    };

    let globals: Vec<_> = global_sizes
        .iter()
        .enumerate()
        .map(|(i, &size)| UOp::range_axis(bound(size), AxisId::Renumbered(i), AxisType::Global))
        .collect();

    let reduce_axis = UOp::range_axis(bound(reduce_size), AxisId::Renumbered(global_sizes.len()), AxisType::Reduce);
    let const_val = UOp::native_const(1.0f32);
    let reduce = const_val.reduce(smallvec![reduce_axis], reduce_op);

    // The REDUCE is the first sink source, followed by the globals.
    let mut sources = Vec::with_capacity(globals.len() + 1);
    sources.push(reduce);
    sources.extend(globals);
    UOp::sink(sources)
}
/// Builds a kernel with the axis structure of an `m x n` matmul with a `k`-long reduction.
///
/// Axes 0 (`m`) and 1 (`n`) are global; axis 2 (`k`) is the reduce axis. The reduced value
/// is `m_range + k_range`, not a real product of two tensors, so only the axis layout
/// resembles a matmul. The sink holds the REDUCE followed by the `m` and `n` ranges.
///
/// # Panics
///
/// Panics if any dimension does not fit in an `i32`.
pub fn create_matmul_pattern(m: i64, n: i64, k: i64) -> Arc<UOp> {
    use smallvec::smallvec;

    // Range bounds must fit in i32; fail loudly rather than silently truncating.
    let bound = |size: i64| {
        let size = i32::try_from(size).unwrap_or_else(|_| panic!("axis size {size} does not fit in i32"));
        UOp::native_const(i64::from(size))
    };

    let (m_uop, n_uop, k_uop) = (bound(m), bound(n), bound(k));
    let m_range = UOp::range_axis(m_uop, AxisId::Renumbered(0), AxisType::Global);
    let n_range = UOp::range_axis(n_uop, AxisId::Renumbered(1), AxisType::Global);
    // `k` is consumed by the REDUCE below, so it is a Reduce axis, consistent with how the
    // other builders in this file tag reduction ranges.
    let k_range = UOp::range_axis(k_uop, AxisId::Renumbered(2), AxisType::Reduce);
    let add_expr = m_range.try_add(&k_range).expect("ADD should succeed with same dtype");
    let reduce = add_expr.reduce(smallvec![k_range], ReduceOp::Add);
    UOp::sink(vec![reduce, m_range, n_range])
}
/// Builds a kernel that reduces the constant `1.0f32` over two reduce axes (ids 0 and 1).
///
/// The sink's only source is the REDUCE.
///
/// # Panics
///
/// Panics if `size1` or `size2` does not fit in an `i32`.
pub fn create_double_reduce(size1: i64, size2: i64, reduce_op: ReduceOp) -> Arc<UOp> {
    use smallvec::smallvec;

    // Range bounds must fit in i32; fail loudly rather than silently truncating.
    let bound = |size: i64| {
        let size = i32::try_from(size).unwrap_or_else(|_| panic!("reduce size {size} does not fit in i32"));
        UOp::native_const(i64::from(size))
    };

    let size1_uop = bound(size1);
    let size2_uop = bound(size2);
    let range1 = UOp::range_axis(size1_uop, AxisId::Renumbered(0), AxisType::Reduce);
    let range2 = UOp::range_axis(size2_uop, AxisId::Renumbered(1), AxisType::Reduce);
    let const_val = UOp::native_const(1.0f32);
    let reduce = const_val.reduce(smallvec![range1, range2], reduce_op);
    UOp::sink(vec![reduce])
}
/// Builds a kernel with exactly two global axes and a REDUCE over exactly two reduce axes.
///
/// Global axes take ids 0 and 1; reduce axes take ids 2 and 3. The reduced value is the
/// constant `1.0f32`. The sink's sources are the REDUCE first, then the two global RANGEs.
///
/// # Panics
///
/// Panics if `global_sizes` or `reduce_sizes` does not have exactly two entries, or if any
/// size does not fit in an `i32`.
pub fn create_double_reduce_with_globals(global_sizes: &[i64], reduce_sizes: &[i64], reduce_op: ReduceOp) -> Arc<UOp> {
    assert_eq!(global_sizes.len(), 2, "Expected 2 global dimensions for double reduce");
    assert_eq!(reduce_sizes.len(), 2, "Expected 2 reduce dimensions for double reduce");

    // Range bounds must fit in i32; fail loudly rather than silently truncating.
    let bound = |size: i64| {
        let size = i32::try_from(size).unwrap_or_else(|_| panic!("axis size {size} does not fit in i32"));
        UOp::native_const(i64::from(size))
    };

    let globals: Vec<_> = global_sizes
        .iter()
        .enumerate()
        .map(|(id, &size)| UOp::range_axis(bound(size), AxisId::Renumbered(id), AxisType::Global))
        .collect();

    // Reduce axes are numbered immediately after the global axes.
    let first_reduce_id = global_sizes.len();
    let reduce_axes = reduce_sizes
        .iter()
        .enumerate()
        .map(|(offset, &size)| {
            UOp::range_axis(bound(size), AxisId::Renumbered(first_reduce_id + offset), AxisType::Reduce)
        })
        .collect();

    let const_val = UOp::native_const(1.0f32);
    let reduce = const_val.reduce(reduce_axes, reduce_op);

    // The REDUCE is the first sink source, followed by the globals.
    let mut sources = Vec::with_capacity(globals.len() + 1);
    sources.push(reduce);
    sources.extend(globals);
    UOp::sink(sources)
}
/// Builds a reduce-free kernel: a sink over the constant `1.0f32` followed by one global
/// RANGE per entry of `sizes` (axis ids `0..sizes.len()`).
///
/// # Panics
///
/// Panics if any size does not fit in an `i32`.
pub fn create_elementwise_pattern(sizes: &[i64]) -> Arc<UOp> {
    let const_val = UOp::native_const(1.0f32);

    // Range bounds must fit in i32; fail loudly rather than silently truncating.
    let bound = |size: i64| {
        let size = i32::try_from(size).unwrap_or_else(|_| panic!("axis size {size} does not fit in i32"));
        UOp::native_const(i64::from(size))
    };

    let ranges = sizes
        .iter()
        .enumerate()
        .map(|(axis_id, &size)| UOp::range_axis(bound(size), AxisId::Renumbered(axis_id), AxisType::Global));
    let ops: Vec<_> = std::iter::once(const_val).chain(ranges).collect();
    UOp::sink(ops)
}
/// Asserts that the scheduler's ranges carry exactly the `expected` axis types, in order.
///
/// # Panics
///
/// Panics if any entry of `scheduler.rngs()` is not an `Op::Range`, if the number of ranges
/// differs from `expected.len()`, or if any axis type differs from its expected counterpart.
pub fn assert_axes_equal(scheduler: &Scheduler, expected: &[AxisType]) {
    let mut actual: Vec<AxisType> = Vec::new();
    for rng in scheduler.rngs().iter() {
        let Op::Range { axis_type, .. } = rng.op() else {
            panic!("Expected Range operation");
        };
        actual.push(*axis_type);
    }

    assert_eq!(actual.len(), expected.len(), "Expected {} axes, got {}: {:?}", expected.len(), actual.len(), actual);

    // Lengths are equal at this point, so indexing `actual` by `expected`'s positions is in bounds.
    for (idx, want) in expected.iter().enumerate() {
        let got = &actual[idx];
        assert_eq!(got, want, "Axis {} type mismatch: expected {:?}, got {:?}", idx, want, got);
    }
}
/// Asserts that `scheduler.full_shape()` matches `expected` dimension by dimension.
///
/// An entry of `-1` in `expected` is a wildcard: only the rank is checked for that position,
/// not its size.
///
/// # Panics
///
/// Panics if the ranks differ or if any non-wildcard dimension has a different size.
pub fn assert_shape_equal(scheduler: &Scheduler, expected: &[i64]) {
    let shape = scheduler.full_shape();
    assert_eq!(shape.len(), expected.len(), "Expected {} dimensions, got {}: {:?}", expected.len(), shape.len(), shape);

    for (dim, (&got, &want)) in shape.iter().zip(expected.iter()).enumerate() {
        if want == -1 {
            continue;
        }
        assert_eq!(got, want, "Dimension {} size mismatch: expected {}, got {}", dim, want, got);
    }
}
/// Asserts that the scheduler reports exactly `expected_count` axes of type `axis_type`
/// (as returned by `scheduler.axes_of`).
pub fn assert_axis_count(scheduler: &Scheduler, axis_type: AxisType, expected_count: usize) {
    let found = scheduler.axes_of(&[axis_type]);
    let found_count = found.len();
    assert_eq!(found_count, expected_count, "Expected {} {:?} axes, got {}", expected_count, axis_type, found_count);
}
/// Applies `opt` to `scheduler` via `apply_opt` and panics with the returned error if the
/// optimization is rejected.
///
/// Returns the same scheduler so calls can be chained.
// Allowed because some builds may not call this helper at all.
#[allow(dead_code)]
pub fn assert_opt_succeeds<'a>(scheduler: &'a mut Scheduler, opt: &crate::optimizer::Opt) -> &'a mut Scheduler {
    if let Err(err) = crate::optimizer::apply_opt(scheduler, opt, true) {
        panic!("Expected optimization {:?} to succeed, but got error: {:?}", opt, err);
    }
    scheduler
}
/// Applies `opt` to `scheduler` via `apply_opt` and panics if the optimization is accepted.
// Allowed because some builds may not call this helper at all.
#[allow(dead_code)]
pub fn assert_opt_fails(scheduler: &mut Scheduler, opt: &crate::optimizer::Opt) {
    let rejected = crate::optimizer::apply_opt(scheduler, opt, true).is_err();
    assert!(rejected, "Expected optimization {:?} to fail, but it succeeded", opt);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimizer::{Opt, Renderer};

    #[test]
    fn test_create_simple_reduce() {
        let reduce = create_simple_reduce(32, ReduceOp::Add);
        assert!(matches!(reduce.op(), Op::Sink { .. }));
    }

    #[test]
    fn test_create_matmul_pattern() {
        let matmul = create_matmul_pattern(16, 16, 16);
        assert!(matches!(matmul.op(), Op::Sink { .. }));
    }

    #[test]
    fn test_create_double_reduce() {
        let reduce = create_double_reduce(8, 8, ReduceOp::Add);
        assert!(matches!(reduce.op(), Op::Sink { .. }));
    }

    #[test]
    fn test_create_elementwise_pattern() {
        let elem = create_elementwise_pattern(&[10, 20, 30]);
        assert!(matches!(elem.op(), Op::Sink { .. }));
    }

    #[test]
    fn test_create_reduce_with_globals() {
        let kernel = create_reduce_with_globals(&[4, 8], 16, ReduceOp::Add);
        assert!(matches!(kernel.op(), Op::Sink { .. }));
        let scheduler = Scheduler::new(kernel, Renderer::cpu());
        assert_axis_count(&scheduler, AxisType::Global, 2);
        assert_axis_count(&scheduler, AxisType::Reduce, 1);
    }

    #[test]
    fn test_create_double_reduce_with_globals() {
        let kernel = create_double_reduce_with_globals(&[4, 8], &[2, 16], ReduceOp::Add);
        assert!(matches!(kernel.op(), Op::Sink { .. }));
        let scheduler = Scheduler::new(kernel, Renderer::cpu());
        assert_axis_count(&scheduler, AxisType::Global, 2);
        assert_axis_count(&scheduler, AxisType::Reduce, 2);
    }

    #[test]
    #[should_panic(expected = "Expected 2 global dimensions")]
    fn test_create_double_reduce_with_globals_rejects_wrong_global_count() {
        let _ = create_double_reduce_with_globals(&[4], &[2, 16], ReduceOp::Add);
    }

    #[test]
    fn test_assert_axes_equal() {
        let reduce = create_simple_reduce(16, ReduceOp::Add);
        let renderer = Renderer::cpu();
        let scheduler = Scheduler::new(reduce, renderer);
        assert_axes_equal(&scheduler, &[AxisType::Reduce]);
    }

    #[test]
    #[should_panic(expected = "Axis 0 type mismatch")]
    fn test_assert_axes_equal_detects_mismatch() {
        let reduce = create_simple_reduce(16, ReduceOp::Add);
        let scheduler = Scheduler::new(reduce, Renderer::cpu());
        assert_axes_equal(&scheduler, &[AxisType::Global]);
    }

    #[test]
    fn test_assert_shape_equal() {
        let reduce = create_simple_reduce(16, ReduceOp::Add);
        let renderer = Renderer::cpu();
        let scheduler = Scheduler::new(reduce, renderer);
        assert_shape_equal(&scheduler, &[16]);
    }

    #[test]
    fn test_assert_shape_equal_wildcard() {
        // `-1` skips the size check for that dimension; only the rank must match.
        let reduce = create_simple_reduce(16, ReduceOp::Add);
        let scheduler = Scheduler::new(reduce, Renderer::cpu());
        assert_shape_equal(&scheduler, &[-1]);
    }

    #[test]
    fn test_assert_axis_count() {
        let elem = create_elementwise_pattern(&[16]);
        let renderer = Renderer::cpu();
        let mut scheduler = Scheduler::new(elem, renderer);
        assert_axis_count(&scheduler, AxisType::Global, 1);
        assert_axis_count(&scheduler, AxisType::Upcast, 0);
        crate::optimizer::apply_opt(&mut scheduler, &Opt::upcast(0, 4), true).unwrap();
        assert_axis_count(&scheduler, AxisType::Global, 1);
        assert_axis_count(&scheduler, AxisType::Upcast, 1);
    }
}