use morok_dtype::DType;
use morok_ir::{AxisId, AxisType, ConstValue, ReduceOp, UOp};
use smallvec::SmallVec;
use crate::llvm::text::render;
/// Renders a single `AxisType::Loop` range (end = 10) closed by `END` over a no-op body,
/// and checks that the emitted LLVM IR contains every expected loop block label and an
/// `i64` PHI node.
#[test]
fn test_range_end_basic() {
    let end = UOp::const_(DType::Index, ConstValue::Int(10));
    let range = UOp::range_axis(end, AxisId::Renumbered(0), AxisType::Loop);
    let noop = UOp::noop();
    let ranges: SmallVec<[_; 4]> = smallvec::smallvec![range];
    let end_op = noop.end(ranges);
    let sink = UOp::sink(vec![end_op]);

    // `expect` reports a codegen error exactly once, formatted as `Codegen failed: <err>`.
    let kernel = render(&sink, Some("test_loop")).expect("Codegen failed");
    let ir = &kernel.code;

    // (substring that must appear in the IR, name used in the failure message)
    let expected = [
        ("loop_entry_", "entry block"),
        ("loop_latch_", "latch block"),
        ("loop_body_", "body block"),
        ("loop_footer_", "footer block"),
        ("loop_exit_", "exit block"),
        ("phi i64", "PHI node"),
    ];
    for (needle, what) in expected {
        assert!(ir.contains(needle), "Missing {}:\n{}", what, ir);
    }
}
/// Renders a `ReduceOp::Add` of a constant `Float32` (5.0) over a reduce range of 10,
/// closed by `END`, and checks that the IR contains the loop block labels, a `float`
/// accumulator `alloca`, and an `fadd` instruction.
#[test]
fn test_reduce_add_basic() {
    let const_val = UOp::const_(DType::Float32, ConstValue::Float(5.0));
    let range =
        UOp::range_axis(UOp::const_(DType::Index, ConstValue::Int(10)), AxisId::Renumbered(0), AxisType::Reduce);
    let reduce = const_val.reduce(smallvec::smallvec![range.clone()], ReduceOp::Add);
    let ranges: SmallVec<[_; 4]> = smallvec::smallvec![range];
    let end_op = reduce.end(ranges);
    let sink = UOp::sink(vec![end_op]);

    // `expect` reports a codegen error exactly once, formatted as `Codegen failed: <err>`.
    let kernel = render(&sink, Some("test_reduce_add")).expect("Codegen failed");
    let ir = &kernel.code;

    // (substring that must appear in the IR, name used in the failure message)
    let expected = [
        ("loop_entry_", "loop entry block"),
        ("loop_latch_", "loop latch block"),
        ("loop_exit_", "loop exit block"),
        ("alloca float", "reduce accumulator alloca"),
        ("fadd", "fadd instruction"),
    ];
    for (needle, what) in expected {
        assert!(ir.contains(needle), "Missing {}:\n{}", what, ir);
    }
}
/// Renders a `ReduceOp::Max` of a constant `Float32` (3.0) over a reduce range of 5,
/// closed by `END`, and checks that the emitted IR mentions `maxnum`.
#[test]
fn test_reduce_max() {
    let const_val = UOp::const_(DType::Float32, ConstValue::Float(3.0));
    let range = UOp::range_axis(UOp::const_(DType::Index, ConstValue::Int(5)), AxisId::Renumbered(0), AxisType::Reduce);
    let reduce = const_val.reduce(smallvec::smallvec![range.clone()], ReduceOp::Max);
    let ranges: SmallVec<[_; 4]> = smallvec::smallvec![range];
    let end_op = reduce.end(ranges);
    let sink = UOp::sink(vec![end_op]);

    // `expect` reports a codegen error exactly once, formatted as `Codegen failed: <err>`.
    let kernel = render(&sink, Some("test_reduce_max")).expect("Codegen failed");
    let ir = &kernel.code;

    // Only the bare `maxnum` substring is required (a `llvm.maxnum.f…` check would be
    // subsumed by it). NOTE(review): tighten to "llvm.maxnum." if the renderer is
    // guaranteed to emit the LLVM intrinsic rather than some other `maxnum` spelling.
    assert!(ir.contains("maxnum"), "Missing maxnum intrinsic:\n{}", ir);
}
/// A `ReduceOp::Add` over a constant with an empty range list, placed directly in the
/// sink with no `END`, must still render without error.
#[test]
fn test_reduce_empty_ranges() {
    let value = UOp::const_(DType::Float32, ConstValue::Float(42.0));
    let sink = UOp::sink(vec![value.reduce(smallvec::smallvec![], ReduceOp::Add)]);

    if let Err(e) = render(&sink, Some("test_reduce_empty")) {
        // Same message the `assert!(is_ok, .., result.err())` form produced.
        panic!("Codegen failed: {:?}", Some(e));
    }
}