use std::sync::Arc;
use morok_dtype::DType;
use morok_ir::types::ConstValue;
use morok_ir::{Op, UOp};
use crate::expand::pre_expand;
use crate::rewrite::graph_rewrite;
/// Runs the crate's `pre_expand` entry point on `uop` and hands back the
/// rewritten graph root. The input graph is only borrowed.
pub fn expander_rewrite(uop: &Arc<UOp>) -> Arc<UOp> {
    pre_expand(uop)
}
/// Applies only the combined matcher
/// `pm_pre_expander() + pm_group_for_reduce() + expander()` to `uop` in a
/// single `graph_rewrite` pass with a unit `()` context.
///
/// The name suggests this is the second phase of the expander pipeline,
/// isolated for testing (NOTE(review): confirm against `crate::expand`).
pub fn phase2_only(uop: &Arc<UOp>) -> Arc<UOp> {
    let pre_expander = crate::expand::pm_pre_expander();
    let group_for_reduce = crate::expand::pm_group_for_reduce();
    let expander = crate::expand::expander();
    // `+` is left-associative: (pre + group) + expander, as in the one-liner form.
    let combined = pre_expander + group_for_reduce + expander;
    graph_rewrite(&combined, Arc::clone(uop), &mut ())
}
/// Builds an Int64 VCONST holding `0, 1, ..., count - 1` and wraps it with
/// `unroll_with_dtype` over the single axis `(axis_id, count)` (Int64 dtype).
pub fn create_unroll_iota(axis_id: usize, count: usize) -> Arc<UOp> {
    let mut lanes = Vec::with_capacity(count);
    for lane in 0..count as i64 {
        lanes.push(ConstValue::Int(lane));
    }
    UOp::vconst(lanes, DType::Int64).unroll_with_dtype(vec![(axis_id, count)], DType::Int64)
}
/// Like an iota unroll, but lane `i` holds `i * scale`: the values are
/// `0, scale, 2*scale, ...` for `count` lanes, wrapped with
/// `unroll_with_dtype` over the single axis `(axis_id, count)` (Int64 dtype).
pub fn create_unroll_scaled(axis_id: usize, count: usize, scale: i64) -> Arc<UOp> {
    let lanes: Vec<ConstValue> = (0..count as i64)
        .map(|step| step * scale)
        .map(ConstValue::Int)
        .collect();
    let payload = UOp::vconst(lanes, DType::Int64);
    payload.unroll_with_dtype(vec![(axis_id, count)], DType::Int64)
}
/// Builds an Int64 VCONST from the given `values` and wraps it with
/// `unroll_with_dtype` over the single axis `(axis_id, values.len())`.
pub fn create_unroll_values(axis_id: usize, values: Vec<i64>) -> Arc<UOp> {
    // The axis size is simply the number of explicit lanes supplied.
    let lane_count = values.len();
    let lanes: Vec<ConstValue> = values.into_iter().map(ConstValue::Int).collect();
    UOp::vconst(lanes, DType::Int64).unroll_with_dtype(vec![(axis_id, lane_count)], DType::Int64)
}
/// Builds an iota VCONST (`0..product of axis sizes`, Int64) and wraps it with
/// `unroll_with_dtype` over all of `axes` (Int64 dtype).
///
/// With an empty `axes` list the product is 1, so the VCONST holds just `[0]`.
pub fn create_unroll_multi_axis(axes: Vec<(usize, usize)>) -> Arc<UOp> {
    let lane_total: usize = axes.iter().map(|&(_, size)| size).product();
    let mut lanes = Vec::with_capacity(lane_total);
    for lane in 0..lane_total as i64 {
        lanes.push(ConstValue::Int(lane));
    }
    UOp::vconst(lanes, DType::Int64).unroll_with_dtype(axes, DType::Int64)
}
/// Same as an iota multi-axis unroll, but passes `dtype` to
/// `unroll_with_dtype` instead of Int64.
///
/// NOTE(review): the inner VCONST payload is always built as Int64; only the
/// dtype handed to `unroll_with_dtype` varies. Confirm that mismatch is
/// intended by the tests that use this helper.
pub fn create_unroll_multi_axis_with_dtype(axes: Vec<(usize, usize)>, dtype: DType) -> Arc<UOp> {
    let lane_total: usize = axes.iter().map(|&(_, size)| size).product();
    let lanes = (0..lane_total as i64).map(ConstValue::Int).collect::<Vec<ConstValue>>();
    let payload = UOp::vconst(lanes, DType::Int64);
    payload.unroll_with_dtype(axes, dtype)
}
/// Builds an Int64 VCONST whose lanes are `values`, in order.
pub fn create_vconst_int(values: Vec<i64>) -> Arc<UOp> {
    UOp::vconst(values.into_iter().map(ConstValue::Int).collect::<Vec<ConstValue>>(), DType::Int64)
}
/// Wraps `src` in a CONTRACT over `axes` using the `UOp::contract` builder.
/// The builder chooses the result dtype. Compare with building the CONTRACT
/// directly with `DType::Void`.
pub fn create_contract(src: Arc<UOp>, axes: Vec<(usize, usize)>) -> Arc<UOp> {
    src.contract(axes)
}
pub fn create_contract_void(src: Arc<UOp>, axes: Vec<(usize, usize)>) -> Arc<UOp> {
UOp::new(morok_ir::Op::Contract { src: src.clone(), upcast_ranges: axes }, DType::Void)
}
/// Asserts that `uop` is an UNROLL whose axes equal `expected_axes`.
///
/// # Panics
/// If `uop` is not an UNROLL, or if its axes differ from `expected_axes`.
pub fn assert_is_unroll(uop: &Arc<UOp>, expected_axes: &[(usize, usize)]) {
    if let Op::Unroll { unroll_axes, .. } = uop.op() {
        assert_eq!(
            unroll_axes.as_slice(),
            expected_axes,
            "Expected UNROLL axes {:?}, got {:?}",
            expected_axes,
            unroll_axes
        );
    } else {
        panic!("Expected UNROLL, got {:?}", uop.op());
    }
}
/// Destructures an UNROLL into `(source, unroll_axes)`. Both are cloned; the
/// `Arc` clone only bumps a refcount.
///
/// # Panics
/// If `uop` is not an UNROLL.
pub fn unwrap_unroll(uop: &Arc<UOp>) -> (Arc<UOp>, Vec<(usize, usize)>) {
    if let Op::Unroll { src, unroll_axes } = uop.op() {
        (Arc::clone(src), unroll_axes.clone())
    } else {
        panic!("Expected UNROLL, got {:?}", uop.op())
    }
}
/// Asserts that `uop` is a VCONST whose lanes, read as `i64`, equal
/// `expected_values`.
///
/// Both `ConstValue::Int` and `ConstValue::UInt` lanes are accepted, the same
/// way `unwrap_vconst` and `extract_const_values` read them. Unsigned lanes
/// are cast with `as i64`.
///
/// # Panics
/// If `uop` is not a VCONST, if a lane is neither `Int` nor `UInt`, or if the
/// lane values differ from `expected_values`.
pub fn assert_is_vconst(uop: &Arc<UOp>, expected_values: &[i64]) {
    match uop.op() {
        Op::VConst { values } => {
            let actual: Vec<i64> = values
                .iter()
                .map(|cv| match cv {
                    ConstValue::Int(i) => *i,
                    ConstValue::UInt(u) => *u as i64,
                    other => panic!("Expected Int, got {:?}", other),
                })
                .collect();
            assert_eq!(actual, expected_values, "VCONST values mismatch");
        }
        other => panic!("Expected VCONST, got {:?}", other),
    }
}
/// Returns the lanes of a VCONST as `i64`. `Int` lanes are taken as-is and
/// `UInt` lanes are cast with `as i64`.
///
/// # Panics
/// If `uop` is not a VCONST, or if a lane is neither `Int` nor `UInt`.
pub fn unwrap_vconst(uop: &Arc<UOp>) -> Vec<i64> {
    if let Op::VConst { values } = uop.op() {
        let mut lanes = Vec::with_capacity(values.len());
        for value in values.iter() {
            let lane = match value {
                ConstValue::Int(i) => *i,
                ConstValue::UInt(u) => *u as i64,
                other => panic!("Expected Int, got {:?}", other),
            };
            lanes.push(lane);
        }
        lanes
    } else {
        panic!("Expected VCONST, got {:?}", uop.op())
    }
}
/// Expects `uop` to be an UNROLL whose source is a VCONST, and returns that
/// VCONST's integer lanes. The unroll axes are discarded.
///
/// # Panics
/// If either layer does not have the expected shape.
pub fn unwrap_unroll_vconst(uop: &Arc<UOp>) -> Vec<i64> {
    unwrap_vconst(&unwrap_unroll(uop).0)
}
/// Asserts that `uop` is a GEP selecting exactly `expected_indices`.
///
/// # Panics
/// If `uop` is not a GEP, or if its indices differ.
pub fn assert_is_gep(uop: &Arc<UOp>, expected_indices: &[usize]) {
    if let Op::Gep { indices, .. } = uop.op() {
        assert_eq!(
            indices, expected_indices,
            "GEP indices mismatch: expected {:?}, got {:?}",
            expected_indices, indices
        );
    } else {
        panic!("Expected GEP, got {:?}", uop.op());
    }
}
/// Destructures a GEP into `(vector, indices)`. Both are cloned.
///
/// # Panics
/// If `uop` is not a GEP.
pub fn unwrap_gep(uop: &Arc<UOp>) -> (Arc<UOp>, Vec<usize>) {
    if let Op::Gep { vector, indices } = uop.op() {
        (Arc::clone(vector), indices.clone())
    } else {
        panic!("Expected GEP, got {:?}", uop.op())
    }
}
/// Asserts that `uop` is a VECTORIZE with exactly `expected_count` elements.
///
/// # Panics
/// If `uop` is not a VECTORIZE, or if the element count differs.
pub fn assert_is_vectorize(uop: &Arc<UOp>, expected_count: usize) {
    let element_count = match uop.op() {
        Op::Vectorize { elements } => elements.len(),
        other => panic!("Expected VECTORIZE, got {:?}", other),
    };
    assert_eq!(element_count, expected_count, "VECTORIZE element count mismatch");
}
/// Asserts that `uop` is a CONTRACT whose upcast ranges equal `expected_axes`.
///
/// # Panics
/// If `uop` is not a CONTRACT, or if its ranges differ.
pub fn assert_is_contract(uop: &Arc<UOp>, expected_axes: &[(usize, usize)]) {
    let ranges = match uop.op() {
        Op::Contract { upcast_ranges, .. } => upcast_ranges.as_slice(),
        other => panic!("Expected CONTRACT, got {:?}", other),
    };
    assert_eq!(ranges, expected_axes, "CONTRACT axes mismatch");
}
/// Destructures a CONTRACT into `(source, upcast_ranges)`. Both are cloned.
///
/// # Panics
/// If `uop` is not a CONTRACT.
pub fn unwrap_contract(uop: &Arc<UOp>) -> (Arc<UOp>, Vec<(usize, usize)>) {
    if let Op::Contract { src, upcast_ranges } = uop.op() {
        (Arc::clone(src), upcast_ranges.clone())
    } else {
        panic!("Expected CONTRACT, got {:?}", uop.op())
    }
}
/// Asserts that `uop.dtype()` equals `expected`.
///
/// # Panics
/// If the dtypes differ.
pub fn assert_dtype(uop: &Arc<UOp>, expected: DType) {
    let actual = uop.dtype();
    assert_eq!(actual, expected, "dtype mismatch");
}
/// Asserts that the vector width (`dtype().vcount()`) of `uop` is `expected`.
///
/// # Panics
/// If the widths differ.
pub fn assert_vcount(uop: &Arc<UOp>, expected: usize) {
    let lanes = uop.dtype().vcount();
    assert_eq!(lanes, expected, "vcount mismatch: expected {}, got {}", expected, lanes);
}
use morok_ir::BufferizeOpts;
/// Builds a `DType::Void` BUFFERIZE node from `compute`, the given `ranges`
/// (collected into the node's range container), and `opts`.
pub fn create_bufferize(compute: Arc<UOp>, ranges: Vec<Arc<UOp>>, opts: BufferizeOpts) -> Arc<UOp> {
    let range_list = ranges.into_iter().collect();
    let node = Op::Bufferize { compute, ranges: range_list, opts };
    UOp::new(node, DType::Void)
}
/// Builds a BUFFERIZE whose options come from `BufferizeOpts::new(DeviceSpec::Cpu)`.
///
/// The "global" in the name refers to whatever `BufferizeOpts::new` implies
/// for a device spec. NOTE(review): confirm that is the global address space.
pub fn create_bufferize_global(compute: Arc<UOp>, ranges: Vec<Arc<UOp>>) -> Arc<UOp> {
    use morok_dtype::DeviceSpec;
    let cpu_opts = BufferizeOpts::new(DeviceSpec::Cpu);
    create_bufferize(compute, ranges, cpu_opts)
}
/// Asserts that `uop` is a BUFFERIZE node.
///
/// # Panics
/// If `uop` has any other op.
pub fn assert_is_bufferize(uop: &Arc<UOp>) {
    if !matches!(uop.op(), Op::Bufferize { .. }) {
        panic!("Expected BUFFERIZE, got {:?}", uop.op());
    }
}
/// Asserts that `uop` is a BUFFERIZE whose `compute` operand is a CONTRACT.
///
/// # Panics
/// If `uop` is not a BUFFERIZE, or if its compute operand is not a CONTRACT.
pub fn assert_bufferize_compute_is_contract(uop: &Arc<UOp>) {
    let compute = match uop.op() {
        Op::Bufferize { compute, .. } => compute,
        other => panic!("Expected BUFFERIZE, got {:?}", other),
    };
    assert!(
        matches!(compute.op(), Op::Contract { .. }),
        "Expected BUFFERIZE.compute to be CONTRACT, got {:?}",
        compute.op()
    );
}
/// Destructures a BUFFERIZE into `(compute, ranges, opts)`. All three are cloned.
///
/// # Panics
/// If `uop` is not a BUFFERIZE.
pub fn unwrap_bufferize(uop: &Arc<UOp>) -> (Arc<UOp>, smallvec::SmallVec<[Arc<UOp>; 4]>, BufferizeOpts) {
    if let Op::Bufferize { compute, ranges, opts } = uop.op() {
        (Arc::clone(compute), ranges.clone(), opts.clone())
    } else {
        panic!("Expected BUFFERIZE, got {:?}", uop.op())
    }
}
/// Closes `computation` over `ranges` via the `UOp::end` builder. The ranges
/// are collected into whatever container `end` takes.
pub fn create_end(computation: Arc<UOp>, ranges: Vec<Arc<UOp>>) -> Arc<UOp> {
    let range_list = ranges.into_iter().collect();
    computation.end(range_list)
}
/// Asserts that `uop` is an END node.
///
/// # Panics
/// If `uop` has any other op.
pub fn assert_is_end(uop: &Arc<UOp>) {
    if !matches!(uop.op(), Op::End { .. }) {
        panic!("Expected END, got {:?}", uop.op());
    }
}
/// Asserts that `uop` is an END whose `computation` operand is a CONTRACT.
///
/// # Panics
/// If `uop` is not an END, or if its computation is not a CONTRACT.
pub fn assert_is_end_with_contract(uop: &Arc<UOp>) {
    let computation = match uop.op() {
        Op::End { computation, .. } => computation,
        other => panic!("Expected END, got {:?}", other),
    };
    assert!(
        matches!(computation.op(), Op::Contract { .. }),
        "Expected END.computation to be CONTRACT, got {:?}",
        computation.op()
    );
}
/// Asserts that `uop` is an END closing exactly `expected` ranges.
///
/// # Panics
/// If `uop` is not an END, or if the range count differs.
pub fn assert_end_ranges_count(uop: &Arc<UOp>, expected: usize) {
    let range_count = match uop.op() {
        Op::End { ranges, .. } => ranges.len(),
        other => panic!("Expected END, got {:?}", other),
    };
    assert_eq!(
        range_count,
        expected,
        "END ranges count mismatch: expected {}, got {}",
        expected,
        range_count
    );
}
/// Destructures an END into `(computation, ranges)`. Both are cloned.
///
/// # Panics
/// If `uop` is not an END.
pub fn unwrap_end(uop: &Arc<UOp>) -> (Arc<UOp>, smallvec::SmallVec<[Arc<UOp>; 4]>) {
    if let Op::End { computation, ranges } = uop.op() {
        (Arc::clone(computation), ranges.clone())
    } else {
        panic!("Expected END, got {:?}", uop.op())
    }
}
/// Counts the nodes reachable from `uop` (including `uop` itself) for which
/// `predicate` returns true.
///
/// The walk follows every child edge without a visited set. A node reachable
/// along several paths is therefore counted once per path, not once per
/// unique node.
pub fn count_ops<F>(uop: &Arc<UOp>, predicate: F) -> usize
where
    F: Fn(&Arc<UOp>) -> bool,
{
    let mut hits = 0usize;
    count_ops_recursive(uop, &predicate, &mut hits);
    hits
}
/// Pre-order recursive worker for `count_ops`. It bumps `count` when `uop`
/// matches, then descends into every child reported by `uop.op().children()`.
fn count_ops_recursive<F>(uop: &Arc<UOp>, predicate: &F, count: &mut usize)
where
    F: Fn(&Arc<UOp>) -> bool,
{
    *count += usize::from(predicate(uop));
    uop.op()
        .children()
        .into_iter()
        .for_each(|child| count_ops_recursive(child, predicate, &mut *count));
}
/// Number of BUFFERIZE nodes reachable from `uop`, using `count_ops`'s
/// per-path counting.
pub fn count_bufferizes(uop: &Arc<UOp>) -> usize {
    fn is_bufferize(node: &Arc<UOp>) -> bool {
        matches!(node.op(), Op::Bufferize { .. })
    }
    count_ops(uop, is_bufferize)
}
/// Number of END nodes reachable from `uop`, using `count_ops`'s per-path
/// counting.
pub fn count_ends(uop: &Arc<UOp>) -> usize {
    fn is_end(node: &Arc<UOp>) -> bool {
        matches!(node.op(), Op::End { .. })
    }
    count_ops(uop, is_end)
}
/// Number of CONTRACT nodes reachable from `uop`, using `count_ops`'s
/// per-path counting.
pub fn count_contracts(uop: &Arc<UOp>) -> usize {
    fn is_contract(node: &Arc<UOp>) -> bool {
        matches!(node.op(), Op::Contract { .. })
    }
    count_ops(uop, is_contract)
}
/// Number of UNROLL nodes reachable from `uop`, using `count_ops`'s per-path
/// counting.
pub fn count_unrolls(uop: &Arc<UOp>) -> usize {
    fn is_unroll(node: &Arc<UOp>) -> bool {
        matches!(node.op(), Op::Unroll { .. })
    }
    count_ops(uop, is_unroll)
}
/// Evaluates a small integer expression graph rooted at `uop` into its lanes.
///
/// Supported shapes:
/// - VCONST: its integer lanes.
/// - UNROLL: transparent; the lanes of its source.
/// - GEP: the selected lanes of its vector.
/// - VECTORIZE: one lane per element. Each element is either an integer CONST
///   or a sub-expression that evaluates to exactly one lane.
/// - Binary: element-wise via `eval_binary_i64`, broadcasting a 1-lane side
///   against the other.
/// - CONST: a single integer lane.
///
/// # Panics
/// On any other op, non-integer constants, out-of-range GEP indices,
/// mismatched vector lengths, or whatever `eval_binary_i64` panics on.
pub fn extract_result_values(uop: &Arc<UOp>) -> Vec<i64> {
    match uop.op() {
        Op::VConst { values } => extract_const_values(values),
        Op::Unroll { src, .. } => extract_result_values(src),
        Op::Gep { vector, indices } => {
            let lanes = extract_result_values(vector);
            let mut picked = Vec::with_capacity(indices.len());
            for &idx in indices.iter() {
                picked.push(lanes[idx]);
            }
            picked
        }
        Op::Vectorize { elements } => {
            let mut lanes = Vec::with_capacity(elements.len());
            for element in elements.iter() {
                let lane = if let Op::Const(cv) = element.op() {
                    match cv.0 {
                        ConstValue::Int(i) => i,
                        ConstValue::UInt(u) => u as i64,
                        _ => panic!("Expected integer constant in VECTORIZE"),
                    }
                } else {
                    let scalar = extract_result_values(element);
                    assert_eq!(scalar.len(), 1, "Expected scalar in VECTORIZE element");
                    scalar[0]
                };
                lanes.push(lane);
            }
            lanes
        }
        Op::Binary(op, lhs, rhs) => {
            let left = extract_result_values(lhs);
            let right = extract_result_values(rhs);
            // Output width: the non-scalar side wins; equal widths pass through.
            let width = match (left.len(), right.len()) {
                (1, n) => n,
                (n, 1) => n,
                (a, b) if a == b => a,
                (a, b) => panic!("Mismatched vector lengths: {} vs {}", a, b),
            };
            // A 1-lane operand is broadcast by always reading its lane 0.
            let lane_of = |side: &Vec<i64>, k: usize| if side.len() == 1 { side[0] } else { side[k] };
            (0..width)
                .map(|k| eval_binary_i64(*op, lane_of(&left, k), lane_of(&right, k)))
                .collect()
        }
        Op::Const(cv) => {
            let scalar = match cv.0 {
                ConstValue::Int(i) => i,
                ConstValue::UInt(u) => u as i64,
                _ => panic!("Expected integer constant"),
            };
            vec![scalar]
        }
        _ => panic!("Cannot extract values from {:?}", uop.op().as_ref()),
    }
}
/// Evaluates one integer binary op on two `i64` lanes.
///
/// All arithmetic uses two's-complement wrapping, so the result does not
/// depend on whether overflow checks are enabled (debug vs. release build):
/// - Add/Sub/Mul wrap on overflow.
/// - Idiv/Mod truncate toward zero (Rust `/` and `%` semantics).
///   `i64::MIN / -1` wraps to `i64::MIN` and `i64::MIN % -1` is `0`, instead
///   of panicking.
/// - Shl/Shr mask the shift amount to its low 6 bits, which is what plain
///   `<<`/`>>` already do in release builds. Shr is arithmetic (sign-extending).
///
/// # Panics
/// On a zero divisor for Idiv/Mod, or on any op not listed above.
fn eval_binary_i64(op: morok_ir::types::BinaryOp, lhs: i64, rhs: i64) -> i64 {
    use morok_ir::types::BinaryOp;
    match op {
        BinaryOp::Add => lhs.wrapping_add(rhs),
        BinaryOp::Sub => lhs.wrapping_sub(rhs),
        BinaryOp::Mul => lhs.wrapping_mul(rhs),
        BinaryOp::Idiv => lhs.wrapping_div(rhs),
        BinaryOp::Mod => lhs.wrapping_rem(rhs),
        BinaryOp::Max => lhs.max(rhs),
        BinaryOp::And => lhs & rhs,
        BinaryOp::Or => lhs | rhs,
        BinaryOp::Xor => lhs ^ rhs,
        // `as u32` keeps the low 32 bits; wrapping_sh* then masks to 0..64,
        // so the net effect is `rhs & 63`, matching release-mode `<<`/`>>`.
        BinaryOp::Shl => lhs.wrapping_shl(rhs as u32),
        BinaryOp::Shr => lhs.wrapping_shr(rhs as u32),
        _ => panic!("Unsupported binary op for i64 eval: {:?}", op),
    }
}
/// Converts a slice of integer constants to `i64`. `UInt` values are cast
/// with `as i64`.
///
/// # Panics
/// If any value is neither `Int` nor `UInt`.
fn extract_const_values(values: &[ConstValue]) -> Vec<i64> {
    let mut lanes = Vec::with_capacity(values.len());
    for value in values {
        let lane = match value {
            ConstValue::Int(i) => *i,
            ConstValue::UInt(u) => *u as i64,
            other => panic!("Expected integer in VCONST, got {:?}", other),
        };
        lanes.push(lane);
    }
    lanes
}
/// Asserts that evaluating `uop` with `extract_result_values` yields exactly
/// `expected`.
///
/// # Panics
/// If the values differ, or if `extract_result_values` itself panics.
pub fn assert_result_values(uop: &Arc<UOp>, expected: &[i64]) {
    assert_eq!(extract_result_values(uop), expected, "Result values mismatch");
}
pub fn try_extract_values(uop: &Arc<UOp>) -> Option<Vec<i64>> {
match uop.op() {
Op::VConst { values } => Some(extract_const_values(values)),
Op::Unroll { src, .. } => try_extract_values(src),
Op::Gep { vector, indices } => {
let src_values = try_extract_values(vector)?;
Some(indices.iter().map(|&i| src_values[i]).collect())
}
Op::Vectorize { elements } => {
let mut values = Vec::with_capacity(elements.len());
for e in elements.iter() {
if let Op::Const(cv) = e.op() {
match cv.0 {
ConstValue::Int(i) => values.push(i),
ConstValue::UInt(u) => values.push(u as i64),
_ => return None,
}
} else {
return None;
}
}
Some(values)
}
_ => None,
}
}