use morok_ir::{Op, UOp};
use std::sync::Arc;
/// Reports whether `uop` denotes a range that performs no iterations.
///
/// - `Op::Range`: empty when the upper bound reported by
///   [`VminVmaxProperty`] is a negative `ConstValue::Int`.
/// - `Op::Const` with `DType::Index`: empty when the constant is an integer
///   zero (either `Int(0)` or `UInt(0)`).
/// - Anything else: never considered empty.
///
/// NOTE(review): treating an `Index` constant `0` as an empty range presumes
/// that this constant stands in for a zero-trip range; confirm it cannot also
/// be produced by folding a size-1 range, which would not be empty.
pub fn is_empty_range(uop: &Arc<UOp>) -> bool {
    use morok_dtype::DType;
    use morok_ir::types::ConstValue;
    use morok_ir::uop::cached_property::CachedProperty;
    use morok_ir::uop::properties::VminVmaxProperty;

    match uop.op() {
        Op::Range { .. } => {
            // Only the upper bound matters: a negative maximum means the
            // range cannot produce any index value.
            let (_, upper) = VminVmaxProperty::get(uop);
            match upper {
                ConstValue::Int(bound) => *bound < 0,
                _ => false,
            }
        }
        Op::Const(value) => {
            uop.dtype() == DType::Index
                && matches!(value.0, ConstValue::Int(0) | ConstValue::UInt(0))
        }
        _ => false,
    }
}
/// Builds the identity-element constant for reduction `op` over `dtype`.
///
/// - `Add` → `0` (`Float(0.0)` for float dtypes, otherwise `Int(0)`)
/// - `Mul` → `1` (`Float(1.0)` for float dtypes, otherwise `Int(1)`)
/// - `Max` → the smallest value of `dtype` (see `dtype_min`)
/// - `Min` → the largest value of `dtype` (see `dtype_max`)
///
/// NOTE(review): for unsigned and bool dtypes, `Add`/`Mul` yield `Int`
/// constants while `Max`/`Min` yield `UInt`/`Bool`; confirm `UOp::const_`
/// normalizes these, or unify the variants.
pub fn reduce_identity(op: morok_ir::types::ReduceOp, dtype: morok_dtype::DType) -> Arc<UOp> {
    use morok_ir::types::ConstValue;
    use morok_ir::types::ReduceOp;

    let is_float = dtype.is_float();
    let identity = match op {
        ReduceOp::Add if is_float => ConstValue::Float(0.0),
        ReduceOp::Add => ConstValue::Int(0),
        ReduceOp::Mul if is_float => ConstValue::Float(1.0),
        ReduceOp::Mul => ConstValue::Int(1),
        // Max starts from the lowest value so any element replaces it;
        // Min symmetrically starts from the highest.
        ReduceOp::Max => dtype_min(&dtype),
        ReduceOp::Min => dtype_max(&dtype),
    };
    UOp::const_(dtype, identity)
}
/// Returns the smallest value representable by `dtype` as a `ConstValue`.
///
/// - Float dtypes → `-inf`
/// - Bool → `false`
/// - Signed integers → the type's `MIN` (`Index` is treated like `Int64`)
/// - Unsigned integers → `UInt(0)`
/// - Any other scalar base → `Int(0)` as a fallback
fn dtype_min(dtype: &morok_dtype::DType) -> morok_ir::types::ConstValue {
    use morok_dtype::ScalarDType as S;
    use morok_ir::types::ConstValue;

    if dtype.is_float() {
        ConstValue::Float(f64::NEG_INFINITY)
    } else if dtype.is_bool() {
        ConstValue::Bool(false)
    } else {
        match dtype.base() {
            // `i64::from` is a lossless widening of the narrower minimum.
            S::Int8 => ConstValue::Int(i64::from(i8::MIN)),
            S::Int16 => ConstValue::Int(i64::from(i16::MIN)),
            S::Int32 => ConstValue::Int(i64::from(i32::MIN)),
            S::Int64 | S::Index => ConstValue::Int(i64::MIN),
            S::UInt8 | S::UInt16 | S::UInt32 | S::UInt64 => ConstValue::UInt(0),
            _ => ConstValue::Int(0),
        }
    }
}
/// Returns the largest value representable by `dtype` as a `ConstValue`.
///
/// - Float dtypes → `+inf`
/// - Bool → `true`
/// - Signed integers → the type's `MAX` (`Index` is treated like `Int64`)
/// - Unsigned integers → the type's `MAX` as `UInt`
/// - Any other scalar base → `Int(0)` as a fallback
fn dtype_max(dtype: &morok_dtype::DType) -> morok_ir::types::ConstValue {
    use morok_dtype::ScalarDType as S;
    use morok_ir::types::ConstValue;

    if dtype.is_float() {
        ConstValue::Float(f64::INFINITY)
    } else if dtype.is_bool() {
        ConstValue::Bool(true)
    } else {
        match dtype.base() {
            // `from` conversions are lossless widenings of the narrower maxima.
            S::Int8 => ConstValue::Int(i64::from(i8::MAX)),
            S::Int16 => ConstValue::Int(i64::from(i16::MAX)),
            S::Int32 => ConstValue::Int(i64::from(i32::MAX)),
            S::Int64 | S::Index => ConstValue::Int(i64::MAX),
            S::UInt8 => ConstValue::UInt(u64::from(u8::MAX)),
            S::UInt16 => ConstValue::UInt(u64::from(u16::MAX)),
            S::UInt32 => ConstValue::UInt(u64::from(u32::MAX)),
            S::UInt64 => ConstValue::UInt(u64::MAX),
            _ => ConstValue::Int(0),
        }
    }
}