use std::sync::Arc;
use morok_dtype::DType;
use morok_ir::{Op, ReduceOp, UOp};
use smallvec::smallvec;
use crate::rangeify::kernel::run_kernel_split_pipeline;
/// Builds a buffer `UOp` via `UOp::new_buffer` using the CPU device spec,
/// the caller-supplied `size`, and the `Float32` dtype.
fn create_buffer(size: usize) -> Arc<UOp> {
    let device = morok_dtype::DeviceSpec::Cpu;
    let dtype = DType::Float32;
    UOp::new_buffer(device, size, dtype)
}
/// Runs the kernel split pipeline on a sink holding the sum of two buffers
/// of size 100.
///
/// Passes when the resulting op has at least one source, or is itself a
/// `Sink` or `Noop`.
#[test]
fn test_binop_fusion_basic() {
    let lhs = create_buffer(100);
    let rhs = create_buffer(100);
    let sum = lhs.try_add(&rhs).unwrap();

    let (result, _ctx) = run_kernel_split_pipeline(UOp::sink(vec![sum]));

    assert!(
        !result.op().sources().is_empty() || matches!(result.op(), Op::Sink { .. } | Op::Noop)
    );
}
/// Runs the kernel split pipeline on a chain of two additions over three
/// buffers of size 100, i.e. `(a + b) + c`.
///
/// Only checks that the result's `toposort()` is non-empty.
#[test]
fn test_binop_chain_fusion() {
    let first = create_buffer(100);
    let second = create_buffer(100);
    let third = create_buffer(100);

    let partial_sum = first.try_add(&second).unwrap();
    let total = partial_sum.try_add(&third).unwrap();

    let (result, _ctx) = run_kernel_split_pipeline(UOp::sink(vec![total]));

    assert!(!result.toposort().is_empty());
}
/// Runs the kernel split pipeline on an addition of two buffers of size 100
/// wrapped in a `Reshape` whose new shape is built from two
/// `index_const(10)` values.
///
/// Only checks that the result's `toposort()` is non-empty.
#[test]
fn test_binop_reshape_fusion() {
    let lhs = create_buffer(100);
    let rhs = create_buffer(100);
    let sum = lhs.try_add(&rhs).unwrap();

    let dims = smallvec![UOp::index_const(10), UOp::index_const(10)];
    let shape = UOp::vectorize(dims);
    let reshape_op = Op::Reshape { src: sum, new_shape: shape };
    let reshaped = UOp::new(reshape_op, DType::Float32);

    let (result, _ctx) = run_kernel_split_pipeline(UOp::sink(vec![reshaped]));

    assert!(!result.toposort().is_empty());
}
/// Runs the kernel split pipeline on an addition of two buffers of size 100
/// that is reshaped (shape built from two `index_const(10)` values) and then
/// permuted with axes `[1, 0]`.
///
/// Only checks that the result's `toposort()` is non-empty.
#[test]
fn test_binop_permute_fusion() {
    let lhs = create_buffer(100);
    let rhs = create_buffer(100);
    let sum = lhs.try_add(&rhs).unwrap();

    let dims = smallvec![UOp::index_const(10), UOp::index_const(10)];
    let shape = UOp::vectorize(dims);
    let grid = UOp::new(Op::Reshape { src: sum, new_shape: shape }, DType::Float32);

    let permute_op = Op::Permute { src: grid, axes: vec![1, 0] };
    let transposed = UOp::new(permute_op, DType::Float32);

    let (result, _ctx) = run_kernel_split_pipeline(UOp::sink(vec![transposed]));

    assert!(!result.toposort().is_empty());
}
/// Runs the kernel split pipeline on a single buffer of size 100 that is
/// reshaped (shape built from two `index_const(10)` values) and then reduced
/// with `ReduceOp::Add` over axes `[1]`.
///
/// Only checks that the result's `toposort()` is non-empty.
#[test]
fn test_reduce_fusion_basic() {
    let input = create_buffer(100);

    let dims = smallvec![UOp::index_const(10), UOp::index_const(10)];
    let shape = UOp::vectorize(dims);
    let grid = UOp::new(Op::Reshape { src: input, new_shape: shape }, DType::Float32);

    let reduce_op = Op::ReduceAxis { src: grid, axes: vec![1], reduce_op: ReduceOp::Add };
    let summed = UOp::new(reduce_op, DType::Float32);

    let (result, _ctx) = run_kernel_split_pipeline(UOp::sink(vec![summed]));

    assert!(!result.toposort().is_empty());
}
/// Runs the kernel split pipeline on an addition of two buffers of size 100
/// that is reshaped (shape built from two `index_const(10)` values) and then
/// reduced with `ReduceOp::Add` over axes `[1]`.
///
/// Only checks that the result's `toposort()` is non-empty.
#[test]
fn test_reduce_binop_fusion() {
    let lhs = create_buffer(100);
    let rhs = create_buffer(100);
    let sum = lhs.try_add(&rhs).unwrap();

    let dims = smallvec![UOp::index_const(10), UOp::index_const(10)];
    let shape = UOp::vectorize(dims);
    let grid = UOp::new(Op::Reshape { src: sum, new_shape: shape }, DType::Float32);

    let reduce_op = Op::ReduceAxis { src: grid, axes: vec![1], reduce_op: ReduceOp::Add };
    let summed = UOp::new(reduce_op, DType::Float32);

    let (result, _ctx) = run_kernel_split_pipeline(UOp::sink(vec![summed]));

    assert!(!result.toposort().is_empty());
}
/// Runs the kernel split pipeline on a buffer of size 100 wrapped in a
/// `Contiguous` op with an empty `opts` list.
///
/// Only checks that the result's `toposort()` is non-empty.
#[test]
fn test_contiguous_forces_realization() {
    let input = create_buffer(100);

    let contiguous_op = Op::Contiguous { src: input, opts: smallvec![] };
    let realized = UOp::new(contiguous_op, DType::Float32);

    let (result, _ctx) = run_kernel_split_pipeline(UOp::sink(vec![realized]));

    assert!(!result.toposort().is_empty());
}
/// Runs the kernel split pipeline on a sink with two outputs that share one
/// intermediate: `(a + b) * c` and `(a + b) * d`, over four buffers of
/// size 100.
///
/// Only checks that the result's `toposort()` is non-empty.
#[test]
fn test_multiple_outputs_same_input() {
    let lhs = create_buffer(100);
    let rhs = create_buffer(100);
    let scale_one = create_buffer(100);
    let scale_two = create_buffer(100);

    // Both products reuse the same sum node.
    let shared_sum = lhs.try_add(&rhs).unwrap();
    let product_one = shared_sum.try_mul(&scale_one).unwrap();
    let product_two = shared_sum.try_mul(&scale_two).unwrap();

    let outputs = vec![product_one, product_two];
    let (result, _ctx) = run_kernel_split_pipeline(UOp::sink(outputs));

    assert!(!result.toposort().is_empty());
}
/// Runs the kernel split pipeline on a sink with no sources.
///
/// Makes no assertion about the returned value; the test passes as long as
/// the call returns without panicking.
#[test]
fn test_empty_sink() {
    let empty = UOp::sink(vec![]);
    let (_result, _ctx) = run_kernel_split_pipeline(empty);
}
/// Runs the kernel split pipeline on a sink whose only source is the
/// constant `1.0f32`.
///
/// Only checks that the result's `toposort()` is non-empty.
#[test]
fn test_single_constant() {
    let one = UOp::native_const(1.0f32);

    let (result, _ctx) = run_kernel_split_pipeline(UOp::sink(vec![one]));

    assert!(!result.toposort().is_empty());
}