use std::sync::Arc;
use morok_ir::{BinaryOp, BufferizeOpts, ConstValue, DType, Op, UOp};
use crate::rangeify::indexing::{get_const_value, is_const, is_identity_value, is_zero_value};
/// Counts the nodes reachable from `uop` (including `uop` itself) whose op
/// satisfies `predicate`.
///
/// The walk recurses into every entry of `op().sources()` without tracking
/// visited nodes. A node reachable along several paths is therefore tested
/// and counted once per path, and heavily shared graphs are re-walked.
pub fn count_ops<F>(uop: &Arc<UOp>, predicate: F) -> usize
where
    F: Fn(&Op) -> bool + Copy,
{
    let own = usize::from(predicate(uop.op()));
    own + uop
        .op()
        .sources()
        .into_iter()
        .map(|child| count_ops(&child, predicate))
        .sum::<usize>()
}
pub fn count_kernels(uop: &Arc<UOp>) -> usize {
count_ops(uop, |op| matches!(op, Op::Kernel { .. }))
}
/// Finds the `Kernel` node associated with `uop`, if there is one.
///
/// - `Kernel`: returns `uop` itself.
/// - `End` whose `computation` is a `Kernel`: returns that computation.
/// - `After`: returns the first dependency (in order) that is either a
///   `Kernel` or an `End` wrapping a `Kernel`. Dependencies are inspected
///   only one level deep; there is no recursion here.
/// - `Sink`: returns the first result found by recursing into its sources
///   in order.
/// - Anything else: `None`.
pub fn extract_kernel(uop: &Arc<UOp>) -> Option<Arc<UOp>> {
    let is_kernel = |node: &Arc<UOp>| matches!(node.op(), Op::Kernel { .. });

    match uop.op() {
        Op::Kernel { .. } => Some(uop.clone()),
        Op::End { computation, .. } if is_kernel(computation) => Some(computation.clone()),
        Op::After { deps, .. } => deps.iter().find_map(|dep| match dep.op() {
            Op::End { computation, .. } if is_kernel(computation) => Some(computation.clone()),
            Op::Kernel { .. } => Some(dep.clone()),
            _ => None,
        }),
        Op::Sink { sources } => sources.iter().find_map(extract_kernel),
        _ => None,
    }
}
pub fn count_codegen_params(uop: &Arc<UOp>) -> usize {
count_ops(uop, |op| matches!(op, Op::Param { device: None, .. }))
}
pub fn count_define_locals(uop: &Arc<UOp>) -> usize {
count_ops(uop, |op| matches!(op, Op::DefineLocal(_)))
}
pub fn count_stores(uop: &Arc<UOp>) -> usize {
count_ops(uop, |op| matches!(op, Op::Store { .. }))
}
pub fn count_ends(uop: &Arc<UOp>) -> usize {
count_ops(uop, |op| matches!(op, Op::End { .. }))
}
pub fn count_bufferizes(uop: &Arc<UOp>) -> usize {
count_ops(uop, |op| matches!(op, Op::Bufferize { .. }))
}
/// Test shorthand for [`UOp::index_const`]: builds a constant node holding `val`.
///
/// NOTE(review): the constructor name suggests the node gets the index dtype.
/// Confirm in `morok_ir`.
pub fn create_const(val: i64) -> Arc<UOp> {
    UOp::index_const(val)
}
/// Test shorthand for [`UOp::range_const`]: builds a range node with a
/// literal `end` bound on axis `axis_id`.
pub fn create_range(end: i64, axis_id: usize) -> Arc<UOp> {
    UOp::range_const(end, axis_id)
}
/// Test shorthand for [`UOp::range`]: builds a range node whose `end` bound is
/// an arbitrary (possibly symbolic) `UOp`, on axis `axis_id`.
pub fn create_range_symbolic(end: Arc<UOp>, axis_id: usize) -> Arc<UOp> {
    UOp::range(end, axis_id)
}
/// Test shorthand for [`UOp::bufferize_global`]: bufferizes `compute` over
/// `ranges`, using whatever options that constructor applies.
pub fn create_bufferize(compute: Arc<UOp>, ranges: Vec<Arc<UOp>>) -> Arc<UOp> {
    UOp::bufferize_global(compute, ranges)
}
/// Test shorthand for [`UOp::bufferize`]: bufferizes `compute` over `ranges`
/// with caller-supplied `opts`.
pub fn create_bufferize_opts(compute: Arc<UOp>, ranges: Vec<Arc<UOp>>, opts: BufferizeOpts) -> Arc<UOp> {
    UOp::bufferize(compute, ranges, opts)
}
#[test]
fn test_is_identity_value() {
    // (constant, op, flag, expected). Judging by the Sub/Idiv rows, the flag
    // appears to mark the constant as the right-hand operand (x - 0 and x / 1
    // are identities only in that position). TODO confirm the parameter's
    // meaning.
    let cases = [
        (ConstValue::Int(0), BinaryOp::Add, false, true),
        (ConstValue::Int(0), BinaryOp::Add, true, true),
        (ConstValue::Float(0.0), BinaryOp::Add, false, true),
        (ConstValue::Int(1), BinaryOp::Mul, false, true),
        (ConstValue::Int(1), BinaryOp::Mul, true, true),
        (ConstValue::Float(1.0), BinaryOp::Mul, false, true),
        (ConstValue::Int(0), BinaryOp::Sub, false, false),
        (ConstValue::Int(0), BinaryOp::Sub, true, true),
        (ConstValue::Int(1), BinaryOp::Idiv, false, false),
        (ConstValue::Int(1), BinaryOp::Idiv, true, true),
        (ConstValue::Int(2), BinaryOp::Add, false, false),
        (ConstValue::Int(0), BinaryOp::Mul, false, false),
    ];

    for (idx, (value, op, flag, expected)) in cases.iter().enumerate() {
        assert_eq!(
            is_identity_value(value, op, *flag),
            *expected,
            "is_identity_value case #{idx}"
        );
    }
}
#[test]
fn test_is_zero_value() {
    // (constant, op, expected): a zero is absorbing for Mul and And, but not for Add.
    let cases = [
        (ConstValue::Int(0), BinaryOp::Mul, true),
        (ConstValue::Float(0.0), BinaryOp::Mul, true),
        (ConstValue::Int(0), BinaryOp::And, true),
        (ConstValue::Int(1), BinaryOp::Mul, false),
        (ConstValue::Int(0), BinaryOp::Add, false),
    ];

    for (idx, (value, op, expected)) in cases.iter().enumerate() {
        assert_eq!(is_zero_value(value, op), *expected, "is_zero_value case #{idx}");
    }
}
#[test]
fn test_get_const_value() {
    // A literal i32 constant reports its value as ConstValue::Int.
    let literal = UOp::native_const(42i32);
    assert_eq!(get_const_value(&literal), Some(ConstValue::Int(42)));

    // A parameter node is not a constant, so no value is extracted.
    let input = UOp::param(0, 1, DType::Float32, None);
    assert_eq!(get_const_value(&input), None);
}
#[test]
fn test_is_const() {
    let node = UOp::native_const(42i32);
    // Only the exact value the node was built from should match.
    for (candidate, expected) in [(ConstValue::Int(42), true), (ConstValue::Int(0), false)] {
        assert_eq!(is_const(&node, &candidate), expected);
    }
}