use std::sync::Arc;
use morok_dtype::{AddrSpace, DType, ScalarDType};
use morok_ir::types::ConstValue;
use morok_ir::{Op, UOp};
use smallvec::SmallVec;
use crate::devectorize::{
bool_storage_patterns, correct_load_store_patterns, devectorize, load_store_folding_patterns,
load_store_indexing_patterns, no_vectorized_alu, pm_render,
};
use crate::rewrite::graph_rewrite;
/// Runs `devectorize` on `uop`, then rewrites the result with the `pm_render`
/// patterns using a unit context.
pub fn apply_devectorize(uop: &Arc<UOp>) -> Arc<UOp> {
    let lowered = devectorize(uop);
    let render_patterns = pm_render();
    graph_rewrite(render_patterns, lowered, &mut ())
}
/// Rewrites `uop` with the patterns returned by `load_store_folding_patterns`,
/// using a unit context, and returns the rewritten graph root.
pub fn apply_load_store_folding(uop: &Arc<UOp>) -> Arc<UOp> {
    let patterns = load_store_folding_patterns();
    graph_rewrite(patterns, Arc::clone(uop), &mut ())
}
/// Rewrites `uop` with the patterns returned by `correct_load_store_patterns`,
/// using a unit context, and returns the rewritten graph root.
pub fn apply_correct_load_store(uop: &Arc<UOp>) -> Arc<UOp> {
    let patterns = correct_load_store_patterns();
    graph_rewrite(patterns, Arc::clone(uop), &mut ())
}
/// Rewrites `uop` with the patterns returned by `bool_storage_patterns`,
/// using a unit context, and returns the rewritten graph root.
pub fn apply_bool_storage(uop: &Arc<UOp>) -> Arc<UOp> {
    let patterns = bool_storage_patterns();
    graph_rewrite(patterns, Arc::clone(uop), &mut ())
}
/// Rewrites `uop` with the patterns returned by `pm_render`, using a unit
/// context, and returns the rewritten graph root.
pub fn apply_pm_render(uop: &Arc<UOp>) -> Arc<UOp> {
    let patterns = pm_render();
    graph_rewrite(patterns, Arc::clone(uop), &mut ())
}
/// Rewrites `uop` with the patterns returned by `no_vectorized_alu`, using a
/// unit context, and returns the rewritten graph root.
pub fn apply_no_vectorized_alu(uop: &Arc<UOp>) -> Arc<UOp> {
    let patterns = no_vectorized_alu();
    graph_rewrite(patterns, Arc::clone(uop), &mut ())
}
/// Alias for `apply_pm_render`: forwards `uop` unchanged and returns whatever
/// that function returns.
pub fn apply_vectorize_normalize(uop: &Arc<UOp>) -> Arc<UOp> {
apply_pm_render(uop)
}
/// Rewrites `uop` with the patterns returned by `load_store_indexing_patterns`,
/// using a unit context, and returns the rewritten graph root.
pub fn apply_load_store_indexing(uop: &Arc<UOp>) -> Arc<UOp> {
    let patterns = load_store_indexing_patterns();
    graph_rewrite(patterns, Arc::clone(uop), &mut ())
}
/// Rewrites `uop` with the patterns returned by
/// `crate::devectorize::devectorize_patterns`, using a unit context.
pub fn apply_cast_after(uop: &Arc<UOp>) -> Arc<UOp> {
    use crate::devectorize::devectorize_patterns;

    let patterns = devectorize_patterns();
    let root = Arc::clone(uop);
    graph_rewrite(patterns, root, &mut ())
}
/// Creates a buffer of `size` via `create_buffer_typed` with a `Float32`
/// scalar element type.
pub fn create_buffer(size: usize) -> Arc<UOp> {
    let element = ScalarDType::Float32;
    create_buffer_typed(size, element)
}
/// Creates a buffer UOp on `DeviceSpec::Cpu` whose dtype is a pointer to
/// `scalar` in `AddrSpace::Global`.
///
/// `size` is passed both as the pointer size and as the buffer size.
pub fn create_buffer_typed(size: usize, scalar: ScalarDType) -> Arc<UOp> {
    let element = DType::Scalar(scalar);
    let pointer = element.ptr(Some(size), AddrSpace::Global);
    UOp::new_buffer(morok_dtype::DeviceSpec::Cpu, size, pointer)
}
/// Creates a buffer UOp on `DeviceSpec::Cpu` whose dtype is a pointer to
/// `scalar` in `AddrSpace::Local`.
///
/// `size` is passed both as the pointer size and as the buffer size.
pub fn create_buffer_local(size: usize, scalar: ScalarDType) -> Arc<UOp> {
    let element = DType::Scalar(scalar);
    let pointer = element.ptr(Some(size), AddrSpace::Local);
    UOp::new_buffer(morok_dtype::DeviceSpec::Cpu, size, pointer)
}
/// Creates a buffer of `size` via `create_buffer_typed` with a `Bool` scalar
/// element type.
pub fn create_bool_buffer(size: usize) -> Arc<UOp> {
    let element = ScalarDType::Bool;
    create_buffer_typed(size, element)
}
/// Builds an index UOp into `buffer` with a single `DType::Index` constant
/// index of value `idx`.
///
/// # Panics
/// Panics if the result of the index builder's `call()` cannot be unwrapped.
pub fn create_index(buffer: Arc<UOp>, idx: i64) -> Arc<UOp> {
    let position = UOp::const_(DType::Index, ConstValue::Int(idx));
    let built = UOp::index().buffer(buffer).indices(vec![position]).call();
    built.unwrap()
}
pub fn create_vector_index_iota(buffer: Arc<UOp>, count: usize) -> Arc<UOp> {
let indices: SmallVec<[Arc<UOp>; 4]> =
(0..count).map(|i| UOp::const_(DType::Index, ConstValue::Int(i as i64))).collect();
let vec_idx = UOp::vectorize(indices);
let idx_dtype = buffer.dtype().base();
let define = buffer_to_define(&buffer);
let buf_vec = define.broadcast(count);
UOp::new(Op::Index { buffer: buf_vec, indices: smallvec::smallvec![vec_idx], gate: None }, DType::Scalar(idx_dtype))
}
/// Creates a parameter UOp carrying `buffer`'s dtype.
///
/// Each call takes the next id from a process-wide counter. The size comes
/// from the pointer dtype when it records one, and falls back to 1024
/// otherwise.
fn buffer_to_define(buffer: &Arc<UOp>) -> Arc<UOp> {
    use std::sync::atomic::{AtomicUsize, Ordering};

    static COUNTER: AtomicUsize = AtomicUsize::new(0);
    let id = COUNTER.fetch_add(1, Ordering::Relaxed);
    let size = if let DType::Ptr { size: Some(s), .. } = buffer.dtype() { s } else { 1024 };
    UOp::param(id, size, buffer.dtype(), None)
}
pub fn create_vector_index_offset(buffer: Arc<UOp>, count: usize, offset: i64) -> Arc<UOp> {
let indices: SmallVec<[Arc<UOp>; 4]> =
(0..count).map(|i| UOp::const_(DType::Index, ConstValue::Int(offset + i as i64))).collect();
let vec_idx = UOp::vectorize(indices);
let idx_dtype = buffer.dtype().base();
let define = buffer_to_define(&buffer);
let buf_vec = define.broadcast(count);
UOp::new(Op::Index { buffer: buf_vec, indices: smallvec::smallvec![vec_idx], gate: None }, DType::Scalar(idx_dtype))
}
pub fn create_vector_index_scaled(buffer: Arc<UOp>, count: usize, scale: i64) -> Arc<UOp> {
let indices: SmallVec<[Arc<UOp>; 4]> =
(0..count).map(|i| UOp::const_(DType::Index, ConstValue::Int(i as i64 * scale))).collect();
let vec_idx = UOp::vectorize(indices);
let idx_dtype = buffer.dtype().base();
let define = buffer_to_define(&buffer);
let buf_vec = define.broadcast(count);
UOp::new(Op::Index { buffer: buf_vec, indices: smallvec::smallvec![vec_idx], gate: None }, DType::Scalar(idx_dtype))
}
pub fn create_vector_index_values(buffer: Arc<UOp>, values: Vec<i64>) -> Arc<UOp> {
let indices: SmallVec<[Arc<UOp>; 4]> =
values.iter().map(|&v| UOp::const_(DType::Index, ConstValue::Int(v))).collect();
let vec_idx = UOp::vectorize(indices);
let idx_dtype = buffer.dtype().base();
let count = values.len();
let define = buffer_to_define(&buffer);
let buf_vec = define.broadcast(count);
UOp::new(Op::Index { buffer: buf_vec, indices: smallvec::smallvec![vec_idx], gate: None }, DType::Scalar(idx_dtype))
}
pub fn create_vector_index_gated(buffer: Arc<UOp>, count: usize, gate: Arc<UOp>) -> Arc<UOp> {
let indices: SmallVec<[Arc<UOp>; 4]> =
(0..count).map(|i| UOp::const_(DType::Index, ConstValue::Int(i as i64))).collect();
let vec_idx = UOp::vectorize(indices);
let idx_dtype = buffer.dtype().base();
let define = buffer_to_define(&buffer);
let buf_vec = define.broadcast(count);
UOp::new(
Op::Index { buffer: buf_vec, indices: smallvec::smallvec![vec_idx], gate: Some(gate) },
DType::Scalar(idx_dtype),
)
}
/// Builds an index into `buffer` of the form `range * scale + offset`.
///
/// `range` is an `Op::Range` with `AxisType::Loop`, `AxisId::Renumbered(axis_id)`,
/// no deps, and an end constant of `bound`. The multiplication is omitted when
/// `scale == 1` and the addition is omitted when `offset == 0`.
///
/// # Panics
/// Panics if the result of the index builder's `call()` cannot be unwrapped.
pub fn create_index_with_range(buffer: Arc<UOp>, axis_id: usize, bound: i64, scale: i64, offset: i64) -> Arc<UOp> {
    use morok_ir::{AxisId, AxisType, BinaryOp};

    let end = UOp::const_(DType::Index, ConstValue::Int(bound));
    let range_op = Op::Range {
        end,
        axis_id: AxisId::Renumbered(axis_id),
        axis_type: AxisType::Loop,
        deps: smallvec::SmallVec::new(),
    };
    let mut idx = UOp::new(range_op, DType::Index);

    if scale != 1 {
        idx = UOp::new(Op::Binary(BinaryOp::Mul, idx, UOp::const_(DType::Index, ConstValue::Int(scale))), DType::Index);
    }
    if offset != 0 {
        idx = UOp::new(Op::Binary(BinaryOp::Add, idx, UOp::const_(DType::Index, ConstValue::Int(offset))), DType::Index);
    }

    let built = UOp::index().buffer(buffer).indices(vec![idx]).call();
    built.unwrap()
}
/// Builds a load UOp from `buffer` at `index` through the `UOp::load()` builder.
pub fn create_load(buffer: Arc<UOp>, index: Arc<UOp>) -> Arc<UOp> {
    let builder = UOp::load().buffer(buffer).index(index);
    builder.call()
}
/// Returns the UOp produced by calling `store(value)` on `index`.
pub fn create_store(index: Arc<UOp>, value: Arc<UOp>) -> Arc<UOp> {
index.store(value)
}
/// Builds a load from `buffer` whose index is the result of
/// `create_vector_index_iota(buffer, count)`.
pub fn create_vector_load_iota(buffer: Arc<UOp>, count: usize) -> Arc<UOp> {
    let lane_index = create_vector_index_iota(Arc::clone(&buffer), count);
    let builder = UOp::load().buffer(buffer).index(lane_index);
    builder.call()
}
/// Stores `value` through the index produced by
/// `create_vector_index_iota(buffer, count)` and returns the resulting UOp.
pub fn create_vector_store_iota(buffer: Arc<UOp>, count: usize, value: Arc<UOp>) -> Arc<UOp> {
    create_vector_index_iota(buffer, count).store(value)
}
/// Creates a `DType::Float32` constant UOp holding `ConstValue::Float(value)`.
pub fn create_float_const(value: f64) -> Arc<UOp> {
    let literal = ConstValue::Float(value);
    UOp::const_(DType::Float32, literal)
}
/// Creates a `DType::Int64` constant UOp holding `ConstValue::Int(value)`.
pub fn create_int_const(value: i64) -> Arc<UOp> {
    let literal = ConstValue::Int(value);
    UOp::const_(DType::Int64, literal)
}
/// Creates a `DType::Bool` constant UOp holding `ConstValue::Bool(value)`.
pub fn create_bool_const(value: bool) -> Arc<UOp> {
    let literal = ConstValue::Bool(value);
    UOp::const_(DType::Bool, literal)
}
pub fn create_vector_float_iota(count: usize) -> Arc<UOp> {
let elements: SmallVec<[Arc<UOp>; 4]> =
(0..count).map(|i| UOp::const_(DType::Float32, ConstValue::Float(i as f64))).collect();
UOp::vectorize(elements)
}
pub fn create_vector_int_iota(count: usize) -> Arc<UOp> {
let elements: SmallVec<[Arc<UOp>; 4]> =
(0..count).map(|i| UOp::const_(DType::Int64, ConstValue::Int(i as i64))).collect();
UOp::vectorize(elements)
}
pub fn create_vector_float_values(values: Vec<f64>) -> Arc<UOp> {
let elements: SmallVec<[Arc<UOp>; 4]> =
values.into_iter().map(|v| UOp::const_(DType::Float32, ConstValue::Float(v))).collect();
UOp::vectorize(elements)
}
pub fn create_vector_int_values(values: Vec<i64>) -> Arc<UOp> {
let elements: SmallVec<[Arc<UOp>; 4]> =
values.into_iter().map(|v| UOp::const_(DType::Int64, ConstValue::Int(v))).collect();
UOp::vectorize(elements)
}
pub fn create_vector_bool(values: Vec<bool>) -> Arc<UOp> {
let elements: SmallVec<[Arc<UOp>; 4]> =
values.into_iter().map(|v| UOp::const_(DType::Bool, ConstValue::Bool(v))).collect();
UOp::vectorize(elements)
}
/// Asserts that `uop` is an `Op::PtrCat` with exactly `expected_count` sources.
///
/// # Panics
/// Panics if the op is not `PtrCat` or the source count differs.
pub fn assert_is_ptrcat(uop: &Arc<UOp>, expected_count: usize) {
    let sources = match uop.op() {
        Op::PtrCat { sources } => sources,
        other => panic!("Expected PTRCAT, got {:?}", other),
    };
    let actual = sources.len();
    assert_eq!(actual, expected_count, "PTRCAT source count mismatch: expected {}, got {}", expected_count, actual);
}
/// Asserts that `uop` is an `Op::Cat` with exactly `expected_count` sources.
///
/// # Panics
/// Panics if the op is not `Cat` or the source count differs.
pub fn assert_is_cat(uop: &Arc<UOp>, expected_count: usize) {
    let sources = match uop.op() {
        Op::Cat { sources } => sources,
        other => panic!("Expected CAT, got {:?}", other),
    };
    let actual = sources.len();
    assert_eq!(actual, expected_count, "CAT source count mismatch: expected {}, got {}", expected_count, actual);
}
/// Asserts that `uop` is an `Op::Vectorize` with exactly `expected_count`
/// elements.
///
/// # Panics
/// Panics if the op is not `Vectorize` or the element count differs.
pub fn assert_is_vectorize(uop: &Arc<UOp>, expected_count: usize) {
    let elements = match uop.op() {
        Op::Vectorize { elements } => elements,
        other => panic!("Expected VECTORIZE, got {:?}", other),
    };
    let actual = elements.len();
    assert_eq!(actual, expected_count, "VECTORIZE element count mismatch: expected {}, got {}", expected_count, actual);
}
/// Asserts that `uop.dtype().vcount()` equals `expected`.
///
/// # Panics
/// Panics on mismatch.
pub fn assert_vcount(uop: &Arc<UOp>, expected: usize) {
    let actual = uop.dtype().vcount();
    assert_eq!(actual, expected, "vcount mismatch: expected {}, got {}", expected, actual);
}
/// Asserts that `uop.dtype()` equals `expected`.
///
/// # Panics
/// Panics on mismatch.
pub fn assert_dtype(uop: &Arc<UOp>, expected: DType) {
    let actual = uop.dtype();
    assert_eq!(actual, expected, "dtype mismatch");
}
/// Asserts that `uop.dtype().base()` equals `expected`.
///
/// # Panics
/// Panics on mismatch.
pub fn assert_base_dtype(uop: &Arc<UOp>, expected: ScalarDType) {
    let actual = uop.dtype().base();
    assert_eq!(actual, expected, "base dtype mismatch");
}
/// Asserts that `uop` is an `Op::Load`.
///
/// # Panics
/// Panics with the actual op in the message otherwise.
pub fn assert_is_load(uop: &Arc<UOp>) {
    match uop.op() {
        Op::Load { .. } => {}
        other => panic!("Expected LOAD, got {:?}", other),
    }
}
/// Asserts that `uop` is an `Op::Store`.
///
/// # Panics
/// Panics with the actual op in the message otherwise.
pub fn assert_is_store(uop: &Arc<UOp>) {
    match uop.op() {
        Op::Store { .. } => {}
        other => panic!("Expected STORE, got {:?}", other),
    }
}
/// Asserts that `uop` is an `Op::Gep` whose indices equal `expected_indices`.
///
/// # Panics
/// Panics if the op is not `Gep` or the indices differ.
pub fn assert_is_gep(uop: &Arc<UOp>, expected_indices: &[usize]) {
    let indices = match uop.op() {
        Op::Gep { indices, .. } => indices,
        other => panic!("Expected GEP, got {:?}", other),
    };
    assert_eq!(indices, expected_indices, "GEP indices mismatch: expected {:?}, got {:?}", expected_indices, indices);
}
/// Asserts that `uop` is an `Op::Cast`.
///
/// # Panics
/// Panics with the actual op in the message otherwise.
pub fn assert_is_cast(uop: &Arc<UOp>) {
    match uop.op() {
        Op::Cast { .. } => {}
        other => panic!("Expected CAST, got {:?}", other),
    }
}
/// Asserts that `uop` is an `Op::Group` with exactly `expected_count` sources.
///
/// # Panics
/// Panics if the op is not `Group` or the source count differs.
pub fn assert_is_group(uop: &Arc<UOp>, expected_count: usize) {
    let sources = match uop.op() {
        Op::Group { sources } => sources,
        other => panic!("Expected GROUP, got {:?}", other),
    };
    let actual = sources.len();
    assert_eq!(actual, expected_count, "GROUP source count mismatch: expected {}, got {}", expected_count, actual);
}
/// Asserts that `uop` is an `Op::Index`.
///
/// # Panics
/// Panics with the actual op in the message otherwise.
pub fn assert_is_index(uop: &Arc<UOp>) {
    match uop.op() {
        Op::Index { .. } => {}
        other => panic!("Expected INDEX, got {:?}", other),
    }
}
/// Counts the visits, starting at `uop` and following `op().children()`
/// recursively, for which `predicate` returns `true`.
///
/// The traversal is delegated to `count_ops_recursive`.
pub fn count_ops<F>(uop: &Arc<UOp>, predicate: F) -> usize
where
    F: Fn(&Arc<UOp>) -> bool,
{
    let mut total: usize = 0;
    count_ops_recursive(uop, &predicate, &mut total);
    total
}
/// Adds one to `count` if `predicate(uop)` holds, then recurses into every
/// child yielded by `uop.op().children()`.
///
/// No visited set is kept, so a node reachable through several paths is
/// tested (and possibly counted) once per path.
fn count_ops_recursive<F>(uop: &Arc<UOp>, predicate: &F, count: &mut usize)
where
    F: Fn(&Arc<UOp>) -> bool,
{
    *count += usize::from(predicate(uop));
    for operand in uop.op().children() {
        count_ops_recursive(operand, predicate, count);
    }
}
/// Counts `Op::Load` visits in the graph rooted at `uop` (see `count_ops`).
pub fn count_loads(uop: &Arc<UOp>) -> usize {
    let is_load = |node: &Arc<UOp>| matches!(node.op(), Op::Load { .. });
    count_ops(uop, is_load)
}
/// Counts `Op::Store` visits in the graph rooted at `uop` (see `count_ops`).
pub fn count_stores(uop: &Arc<UOp>) -> usize {
    let is_store = |node: &Arc<UOp>| matches!(node.op(), Op::Store { .. });
    count_ops(uop, is_store)
}
/// Counts `Op::Index` visits in the graph rooted at `uop` (see `count_ops`).
pub fn count_indices(uop: &Arc<UOp>) -> usize {
    let is_index = |node: &Arc<UOp>| matches!(node.op(), Op::Index { .. });
    count_ops(uop, is_index)
}
/// Counts `Op::PtrCat` visits in the graph rooted at `uop` (see `count_ops`).
pub fn count_ptrcats(uop: &Arc<UOp>) -> usize {
    let is_ptrcat = |node: &Arc<UOp>| matches!(node.op(), Op::PtrCat { .. });
    count_ops(uop, is_ptrcat)
}
/// Counts `Op::Cat` visits in the graph rooted at `uop` (see `count_ops`).
pub fn count_cats(uop: &Arc<UOp>) -> usize {
    let is_cat = |node: &Arc<UOp>| matches!(node.op(), Op::Cat { .. });
    count_ops(uop, is_cat)
}
/// Counts `Op::Vectorize` visits in the graph rooted at `uop` (see `count_ops`).
pub fn count_vectorizes(uop: &Arc<UOp>) -> usize {
    let is_vectorize = |node: &Arc<UOp>| matches!(node.op(), Op::Vectorize { .. });
    count_ops(uop, is_vectorize)
}
/// Counts `Op::Gep` visits in the graph rooted at `uop` (see `count_ops`).
pub fn count_geps(uop: &Arc<UOp>) -> usize {
    let is_gep = |node: &Arc<UOp>| matches!(node.op(), Op::Gep { .. });
    count_ops(uop, is_gep)
}
/// Counts `Op::Cast` visits in the graph rooted at `uop` (see `count_ops`).
pub fn count_casts(uop: &Arc<UOp>) -> usize {
    let is_cast = |node: &Arc<UOp>| matches!(node.op(), Op::Cast { .. });
    count_ops(uop, is_cast)
}
/// Returns a clone of the sources of an `Op::PtrCat`.
///
/// # Panics
/// Panics if `uop` is not a `PtrCat`.
pub fn unwrap_ptrcat(uop: &Arc<UOp>) -> SmallVec<[Arc<UOp>; 4]> {
    let op = uop.op();
    if let Op::PtrCat { sources } = op {
        sources.clone()
    } else {
        panic!("Expected PTRCAT, got {:?}", op)
    }
}
/// Returns a clone of the sources of an `Op::Cat`.
///
/// # Panics
/// Panics if `uop` is not a `Cat`.
pub fn unwrap_cat(uop: &Arc<UOp>) -> SmallVec<[Arc<UOp>; 4]> {
    let op = uop.op();
    if let Op::Cat { sources } = op {
        sources.clone()
    } else {
        panic!("Expected CAT, got {:?}", op)
    }
}
/// Returns a clone of the elements of an `Op::Vectorize`.
///
/// # Panics
/// Panics if `uop` is not a `Vectorize`.
pub fn unwrap_vectorize(uop: &Arc<UOp>) -> SmallVec<[Arc<UOp>; 4]> {
    let op = uop.op();
    if let Op::Vectorize { elements } = op {
        elements.clone()
    } else {
        panic!("Expected VECTORIZE, got {:?}", op)
    }
}
/// Returns clones of the `(buffer, index)` fields of an `Op::Load`.
///
/// # Panics
/// Panics if `uop` is not a `Load`.
pub fn unwrap_load(uop: &Arc<UOp>) -> (Arc<UOp>, Arc<UOp>) {
    let op = uop.op();
    if let Op::Load { buffer, index, .. } = op {
        (buffer.clone(), index.clone())
    } else {
        panic!("Expected LOAD, got {:?}", op)
    }
}
/// Returns clones of the `(index, value)` fields of an `Op::Store`.
///
/// # Panics
/// Panics if `uop` is not a `Store`.
pub fn unwrap_store(uop: &Arc<UOp>) -> (Arc<UOp>, Arc<UOp>) {
    let op = uop.op();
    if let Op::Store { index, value, .. } = op {
        (index.clone(), value.clone())
    } else {
        panic!("Expected STORE, got {:?}", op)
    }
}
/// Returns clones of the `(vector, indices)` fields of an `Op::Gep`.
///
/// # Panics
/// Panics if `uop` is not a `Gep`.
pub fn unwrap_gep(uop: &Arc<UOp>) -> (Arc<UOp>, Vec<usize>) {
    let op = uop.op();
    if let Op::Gep { vector, indices } = op {
        (vector.clone(), indices.clone())
    } else {
        panic!("Expected GEP, got {:?}", op)
    }
}
/// Returns clones of the `(src, dtype)` fields of an `Op::Cast`.
///
/// # Panics
/// Panics if `uop` is not a `Cast`.
pub fn unwrap_cast(uop: &Arc<UOp>) -> (Arc<UOp>, DType) {
    let op = uop.op();
    if let Op::Cast { src, dtype } = op {
        (src.clone(), dtype.clone())
    } else {
        panic!("Expected CAST, got {:?}", op)
    }
}
/// Returns clones of the `(buffer, indices, gate)` fields of an `Op::Index`.
///
/// # Panics
/// Panics if `uop` is not an `Index`.
// The tuple return type mirrors the three fields of `Op::Index`.
#[allow(clippy::type_complexity)]
pub fn unwrap_index(uop: &Arc<UOp>) -> (Arc<UOp>, SmallVec<[Arc<UOp>; 4]>, Option<Arc<UOp>>) {
    let op = uop.op();
    if let Op::Index { buffer, indices, gate } = op {
        (buffer.clone(), indices.clone(), gate.clone())
    } else {
        panic!("Expected INDEX, got {:?}", op)
    }
}
/// Returns the sources of an `Op::Group` copied into a `Vec`.
///
/// # Panics
/// Panics if `uop` is not a `Group`.
pub fn unwrap_group(uop: &Arc<UOp>) -> Vec<Arc<UOp>> {
    let op = uop.op();
    if let Op::Group { sources } = op {
        sources.to_vec()
    } else {
        panic!("Expected GROUP, got {:?}", op)
    }
}
use morok_ir::{AxisId, AxisType, ReduceOp};
/// Rewrites `uop` with the `pm_reduce` patterns, using a default-constructed
/// `ReduceContext` as the rewrite context.
pub fn apply_pm_reduce(uop: &Arc<UOp>) -> Arc<UOp> {
    use crate::devectorize::{ReduceContext, pm_reduce};

    let mut reduce_ctx = ReduceContext::default();
    let matcher = pm_reduce();
    graph_rewrite(&matcher, Arc::clone(uop), &mut reduce_ctx)
}
/// Alias for `apply_load_store_folding`: forwards `uop` unchanged and returns
/// whatever that function returns.
pub fn apply_gep_movement(uop: &Arc<UOp>) -> Arc<UOp> {
apply_load_store_folding(uop)
}
/// Builds a reduce of `src` over `ranges` with `reduce_op`, collecting
/// `ranges` into the container type `reduce` expects.
pub fn create_reduce(src: Arc<UOp>, ranges: Vec<Arc<UOp>>, reduce_op: ReduceOp) -> Arc<UOp> {
    let axes = ranges.into_iter().collect();
    src.reduce(axes, reduce_op)
}
/// Creates an `AxisType::Loop` range with `AxisId::Renumbered(axis_id)` whose
/// end is the `DType::Index` constant `end`.
pub fn create_range_loop(end: i64, axis_id: u32) -> Arc<UOp> {
    let bound = UOp::const_(DType::Index, ConstValue::Int(end));
    let axis = AxisId::Renumbered(axis_id as usize);
    UOp::range_axis(bound, axis, AxisType::Loop)
}
/// Creates an `AxisType::Reduce` range with `AxisId::Renumbered(axis_id)`
/// whose end is the `DType::Index` constant `end`.
pub fn create_range_reduce(end: i64, axis_id: u32) -> Arc<UOp> {
    let bound = UOp::const_(DType::Index, ConstValue::Int(end));
    let axis = AxisId::Renumbered(axis_id as usize);
    UOp::range_axis(bound, axis, AxisType::Reduce)
}
/// Creates an `AxisType::Thread` range with `AxisId::Renumbered(axis_id)`
/// whose end is the `DType::Index` constant `end`.
pub fn create_range_thread(end: i64, axis_id: u32) -> Arc<UOp> {
    let bound = UOp::const_(DType::Index, ConstValue::Int(end));
    let axis = AxisId::Renumbered(axis_id as usize);
    UOp::range_axis(bound, axis, AxisType::Thread)
}
/// Creates an `AxisType::Global` range with `AxisId::Renumbered(axis_id)`
/// whose end is the `DType::Index` constant `end`.
pub fn create_range_global(end: i64, axis_id: u32) -> Arc<UOp> {
    let bound = UOp::const_(DType::Index, ConstValue::Int(end));
    let axis = AxisId::Renumbered(axis_id as usize);
    UOp::range_axis(bound, axis, AxisType::Global)
}
/// Creates an `AxisType::Local` range with `AxisId::Renumbered(axis_id)`
/// whose end is the `DType::Index` constant `end`.
pub fn create_range_local(end: i64, axis_id: u32) -> Arc<UOp> {
    let bound = UOp::const_(DType::Index, ConstValue::Int(end));
    let axis = AxisId::Renumbered(axis_id as usize);
    UOp::range_axis(bound, axis, AxisType::Local)
}
/// Asserts that `uop` is an `Op::DefineReg`.
///
/// # Panics
/// Panics with the actual op in the message otherwise.
pub fn assert_is_define_reg(uop: &Arc<UOp>) {
    match uop.op() {
        Op::DefineReg { .. } => {}
        other => panic!("Expected DEFINE_REG, got {:?}", other),
    }
}
/// Asserts that `uop` is an `Op::After` with exactly `count` deps.
///
/// # Panics
/// Panics if the op is not `After` or the dep count differs.
pub fn assert_has_after_deps(uop: &Arc<UOp>, count: usize) {
    let deps = match uop.op() {
        Op::After { deps, .. } => deps,
        other => panic!("Expected AFTER, got {:?}", other),
    };
    let actual = deps.len();
    assert_eq!(actual, count, "Expected {} AFTER deps, got {}", count, actual);
}
/// Asserts that `uop` is an `Op::End`.
///
/// # Panics
/// Panics with the actual op in the message otherwise.
pub fn assert_is_end(uop: &Arc<UOp>) {
    match uop.op() {
        Op::End { .. } => {}
        other => panic!("Expected END, got {:?}", other),
    }
}
/// Asserts that `uop` is an `Op::Reduce`.
///
/// # Panics
/// Panics with the actual op in the message otherwise.
pub fn assert_is_reduce(uop: &Arc<UOp>) {
    match uop.op() {
        Op::Reduce { .. } => {}
        other => panic!("Expected REDUCE, got {:?}", other),
    }
}
/// Returns the `(src, ranges, reduce_op)` fields of an `Op::Reduce`, cloning
/// `src` and `ranges` and copying `reduce_op`.
///
/// # Panics
/// Panics if `uop` is not a `Reduce`.
pub fn unwrap_reduce(uop: &Arc<UOp>) -> (Arc<UOp>, SmallVec<[Arc<UOp>; 4]>, ReduceOp) {
    let op = uop.op();
    if let Op::Reduce { src, ranges, reduce_op } = op {
        (src.clone(), ranges.clone(), *reduce_op)
    } else {
        panic!("Expected REDUCE, got {:?}", op)
    }
}
/// Returns the UOp produced by calling `gep(indices)` on `vector`.
pub fn create_gep(vector: Arc<UOp>, indices: Vec<usize>) -> Arc<UOp> {
vector.gep(indices)
}
/// Builds a load from `buffer` whose index is `index.gep(gep_indices)`.
pub fn create_load_with_gep_index(buffer: Arc<UOp>, index: Arc<UOp>, gep_indices: Vec<usize>) -> Arc<UOp> {
    let permuted_index = index.gep(gep_indices);
    let builder = UOp::load().buffer(buffer).index(permuted_index);
    builder.call()
}
/// Stores `value` through `index.gep(gep_indices)` via `store_with_ranges`,
/// passing `ranges` along unchanged.
pub fn create_store_with_gep_index(
    index: Arc<UOp>,
    gep_indices: Vec<usize>,
    value: Arc<UOp>,
    ranges: SmallVec<[Arc<UOp>; 4]>,
) -> Arc<UOp> {
    index.gep(gep_indices).store_with_ranges(value, ranges)
}
/// Returns the positions of `indices` ordered by the value stored at each
/// position (ties keep their original order).
///
/// When `indices` is a permutation of `0..n`, the result is its inverse:
/// `result[indices[i]] == i`.
pub fn compute_inverse_permutation(indices: &[usize]) -> Vec<usize> {
    let mut positions: Vec<usize> = (0..indices.len()).collect();
    // `sort_by_key` is stable, so equal values stay in ascending position order.
    positions.sort_by_key(|&pos| indices[pos]);
    positions
}
/// Counts `Op::Reduce` visits in the graph rooted at `uop` (see `count_ops`).
pub fn count_reduces(uop: &Arc<UOp>) -> usize {
    let is_reduce = |node: &Arc<UOp>| matches!(node.op(), Op::Reduce { .. });
    count_ops(uop, is_reduce)
}
/// Counts `Op::DefineReg` visits in the graph rooted at `uop` (see `count_ops`).
pub fn count_define_regs(uop: &Arc<UOp>) -> usize {
    let is_define_reg = |node: &Arc<UOp>| matches!(node.op(), Op::DefineReg { .. });
    count_ops(uop, is_define_reg)
}
/// Counts `Op::End` visits in the graph rooted at `uop` (see `count_ops`).
pub fn count_ends(uop: &Arc<UOp>) -> usize {
    let is_end = |node: &Arc<UOp>| matches!(node.op(), Op::End { .. });
    count_ops(uop, is_end)
}
/// Counts `Op::Range` visits in the graph rooted at `uop` (see `count_ops`).
pub fn count_ranges(uop: &Arc<UOp>) -> usize {
    let is_range = |node: &Arc<UOp>| matches!(node.op(), Op::Range { .. });
    count_ops(uop, is_range)
}
/// Counts `Op::Range` visits whose `axis_type` equals `target_type` in the
/// graph rooted at `uop` (see `count_ops`).
pub fn count_ranges_by_type(uop: &Arc<UOp>, target_type: AxisType) -> usize {
    let has_target_axis = |node: &Arc<UOp>| match node.op() {
        Op::Range { axis_type, .. } => *axis_type == target_type,
        _ => false,
    };
    count_ops(uop, has_target_axis)
}