use std::sync::Arc;
use morok_dtype::DType;
use morok_ir::types::ConstValue;
use morok_ir::{AxisId, AxisType, Op, ReduceOp, UOp};
use smallvec::smallvec;
use crate::expand::pre_expand;
use crate::optimizer::Renderer;
use crate::optimizer::Scheduler;
/// Builds a `REDUCE` with `ReduceOp::Add` whose source is the `Float32` constant `0.0`.
///
/// The single reduce range is created with `AxisType::Reduce` and id
/// `AxisId::Renumbered(axis_id)`. Its end is a `DType::Index` constant holding `size`.
fn create_simple_reduce(size: usize, axis_id: usize) -> Arc<UOp> {
    let bound = UOp::const_(DType::Index, ConstValue::Int(size as i64));
    let reduce_rng = UOp::range_axis(bound, AxisId::Renumbered(axis_id), AxisType::Reduce);
    UOp::const_(DType::Float32, ConstValue::Float(0.0))
        .reduce(smallvec![reduce_rng], ReduceOp::Add)
}
/// Integration test: unrolling the reduce axis via `Scheduler::shift_to` and then
/// running `pre_expand` must leave every `REDUCE` node well-formed.
///
/// For each `Op::Reduce` in the expanded graph this checks that:
/// - its `src` is a CONTRACT, VECTORIZE, Const or VConst node, and
/// - none of its `ranges` is an `Op::Binary` node.
///
/// At least one `REDUCE` must survive expansion.
#[test]
fn test_scheduler_shift_to_integration() {
    // AST: store a 16-wide reduce into index 0, under a loop range whose end is 1.
    let reduce = create_simple_reduce(16, 0);
    let index = UOp::index_const(0);
    let loop_end = UOp::const_(DType::Index, ConstValue::Int(1));
    let loop_range = UOp::range_axis(loop_end, AxisId::Renumbered(1), AxisType::Loop);
    // `reduce`, `loop_range` and `ast` are not used again, so they are moved rather than cloned.
    let store = index.store_with_ranges(reduce, smallvec![loop_range]);
    let ast = UOp::sink(vec![store]);

    let mut scheduler = Scheduler::new(ast, Renderer::cpu());

    // The reduce range must be present. Previously a missing range made the
    // test skip everything below and pass vacuously.
    let reduce_range = scheduler
        .rngs()
        .iter()
        .find(|r| matches!(r.op(), Op::Range { axis_type: AxisType::Reduce, .. }))
        .cloned()
        .expect("scheduler should expose the REDUCE range of the AST");

    let result = scheduler.shift_to(reduce_range, 4, AxisType::Unroll, false, None);
    assert!(result.is_ok(), "shift_to should succeed");

    let optimized = scheduler.get_optimized_ast(None);
    let expanded = pre_expand(&optimized);

    let mut found_reduce = false;
    for node in expanded.toposort() {
        if let Op::Reduce { src, ranges, .. } = node.op() {
            found_reduce = true;
            assert!(
                matches!(src.op(), Op::Contract { .. } | Op::Vectorize { .. } | Op::Const(_) | Op::VConst { .. }),
                "REDUCE.src should be CONTRACT/VECTORIZE/Const/VConst after expansion, got {:?}",
                src.op()
            );
            for range in ranges.iter() {
                assert!(
                    !matches!(range.op(), Op::Binary(..)),
                    "REDUCE.ranges should not contain Binary after expansion"
                );
            }
        }
    }
    assert!(found_reduce, "Should find REDUCE in expanded AST");
}