use crate::types::{BinaryOp, ConstValue, TernaryOp, UnaryOp};
use crate::{Op, UOp};
use morok_dtype::DType;
use std::cmp::Ordering;
use std::sync::Arc;
pub fn compute_sound_vmin_vmax(uop: &Arc<UOp>) -> Option<(ConstValue, ConstValue)> {
use crate::uop::cached_property::CachedProperty;
use crate::uop::properties::SoundVminVmaxProperty;
match &uop.op {
Op::Const(c) => Some((c.0, c.0)),
Op::VConst { values } => Some(sources_range_values(values, &uop.dtype)),
Op::DefineVar { min_val, max_val, .. } => Some((ConstValue::Int(*min_val), ConstValue::Int(*max_val))),
Op::Range { end, .. } | Op::Special { end, .. } => Some(zero_to_end_minus_one(end, &uop.dtype)),
Op::Unroll { src, .. } | Op::Bind { var: src, .. } | Op::Gep { vector: src, .. } => {
*SoundVminVmaxProperty::get(src)
}
Op::Vectorize { elements } => sound_sources_range(elements),
Op::Cat { sources } => sound_sources_range(sources),
Op::Unary(op, src) => {
let (src_min, src_max) = (*SoundVminVmaxProperty::get(src))?;
match op {
UnaryOp::Neg | UnaryOp::Not => Some(compute_unary_range(*op, src_min, src_max, &uop.dtype)),
_ => None,
}
}
Op::Binary(op, a, b) => {
let (a_min, a_max) = (*SoundVminVmaxProperty::get(a))?;
let (b_min, b_max) = (*SoundVminVmaxProperty::get(b))?;
if a_min == a_max && b_min == b_max {
return Some(compute_binary_range(*op, a_min, a_max, b_min, b_max, &uop.dtype));
}
match op {
BinaryOp::Add
| BinaryOp::Sub
| BinaryOp::Mul
| BinaryOp::Max
| BinaryOp::Mod
| BinaryOp::Idiv
| BinaryOp::Shl
| BinaryOp::Shr
| BinaryOp::Lt
| BinaryOp::Le
| BinaryOp::Eq
| BinaryOp::Ne
| BinaryOp::Gt
| BinaryOp::Ge => Some(compute_binary_range(*op, a_min, a_max, b_min, b_max, &uop.dtype)),
BinaryOp::And | BinaryOp::Or if uop.dtype == DType::Bool => {
Some(compute_binary_range(*op, a_min, a_max, b_min, b_max, &uop.dtype))
}
BinaryOp::And
if uop.dtype.is_int() && b_min == b_max && matches!(b_min, ConstValue::Int(v) if v >= 0) =>
{
Some(compute_binary_range(*op, a_min, a_max, b_min, b_max, &uop.dtype))
}
_ => None,
}
}
Op::Ternary(op, a, b, c) => {
let (a_min, a_max) = (*SoundVminVmaxProperty::get(a))?;
let (b_min, b_max) = (*SoundVminVmaxProperty::get(b))?;
let (c_min, c_max) = (*SoundVminVmaxProperty::get(c))?;
if a_min == a_max && b_min == b_max && c_min == c_max {
return Some(compute_ternary_range(*op, a_min, a_max, b_min, b_max, c_min, c_max, &uop.dtype));
}
match op {
TernaryOp::Where if uop.dtype.is_int() || uop.dtype == DType::Index => {
Some(compute_ternary_range(*op, a_min, a_max, b_min, b_max, c_min, c_max, &uop.dtype))
}
_ => None,
}
}
Op::Cast { src, .. } => {
let dt = &uop.dtype;
if !(dt.is_float() || dt.is_signed() || *dt == DType::Index) {
return None;
}
let (src_min, src_max) = (*SoundVminVmaxProperty::get(src))?;
let has_special = matches!(src_min, ConstValue::Float(f) if f.is_nan() || f.is_infinite())
|| matches!(src_max, ConstValue::Float(f) if f.is_nan() || f.is_infinite());
if has_special {
return None;
}
let (target_min, target_max) = dtype_bounds(dt);
let clamped_min = clamp_value(src_min, target_min, target_max);
let clamped_max = clamp_value(src_max, target_min, target_max);
Some((clamped_min.cast(dt).unwrap_or(target_min), clamped_max.cast(dt).unwrap_or(target_max)))
}
_ => None,
}
}
/// Range of an index that counts from `0` up to (but excluding) `end`.
///
/// The lower bound is always `Int(0)`. The upper bound is `vmax(end) - 1`.
/// The dtype's maximum is used instead when that subtraction would overflow
/// (`UInt(0)` or `Int(i64::MIN)`) or when `end`'s bound is not an integer.
fn zero_to_end_minus_one(end: &Arc<UOp>, dtype: &DType) -> (ConstValue, ConstValue) {
    use crate::uop::cached_property::CachedProperty;
    use crate::uop::properties::VminVmaxProperty;
    let (_, end_max) = VminVmaxProperty::get(end);
    let max = match end_max {
        ConstValue::Int(v) => v.checked_sub(1).map(ConstValue::Int),
        ConstValue::UInt(v) => v.checked_sub(1).map(ConstValue::UInt),
        _ => None,
    }
    .unwrap_or_else(|| dtype_bounds(dtype).1);
    (ConstValue::Int(0), max)
}
/// Sound union of the ranges of all `sources`.
///
/// Returns `None` when `sources` is empty or when any source lacks a sound
/// range.
fn sound_sources_range(sources: &[Arc<UOp>]) -> Option<(ConstValue, ConstValue)> {
    use crate::uop::cached_property::CachedProperty;
    use crate::uop::properties::SoundVminVmaxProperty;
    let (head, tail) = sources.split_first()?;
    let mut acc = (*SoundVminVmaxProperty::get(head))?;
    for src in tail {
        let (lo, hi) = (*SoundVminVmaxProperty::get(src))?;
        acc = (min_value(acc.0, lo), max_value(acc.1, hi));
    }
    Some(acc)
}
/// Smallest and largest element of `values`.
///
/// Returns the full range of `dtype` when `values` is empty.
fn sources_range_values(values: &[ConstValue], dtype: &DType) -> (ConstValue, ConstValue) {
    match values.split_first() {
        None => dtype_bounds(dtype),
        Some((&first, rest)) => {
            let mut lo = first;
            let mut hi = first;
            for &v in rest {
                lo = min_value(lo, v);
                hi = max_value(hi, v);
            }
            (lo, hi)
        }
    }
}
fn compute_unary_range(op: UnaryOp, vmin: ConstValue, vmax: ConstValue, dtype: &DType) -> (ConstValue, ConstValue) {
use crate::uop::eval::eval_unary_op;
match op {
UnaryOp::Neg => {
let new_min = eval_unary_op(UnaryOp::Neg, vmax).unwrap_or_else(|| dtype_bounds(dtype).0);
let new_max = eval_unary_op(UnaryOp::Neg, vmin).unwrap_or_else(|| dtype_bounds(dtype).1);
(new_min, new_max)
}
UnaryOp::Abs => {
let crosses_zero = match (vmin, vmax) {
(ConstValue::Int(min), ConstValue::Int(max)) => min <= 0 && max >= 0,
(ConstValue::Float(min), ConstValue::Float(max)) => min <= 0.0 && max >= 0.0,
_ => false,
};
if crosses_zero {
let zero = match vmin {
ConstValue::Int(_) => ConstValue::Int(0),
ConstValue::UInt(_) => ConstValue::UInt(0),
ConstValue::Float(_) => ConstValue::Float(0.0),
_ => dtype_bounds(dtype).0,
};
let abs_min = eval_unary_op(UnaryOp::Abs, vmin);
let abs_max = eval_unary_op(UnaryOp::Abs, vmax);
let max_val = match (abs_min, abs_max) {
(Some(a), Some(b)) => {
if compare_const_values(&a, &b) == Ordering::Greater {
a
} else {
b
}
}
_ => dtype_bounds(dtype).1,
};
(zero, max_val)
} else {
let val_min = eval_unary_op(op, vmin);
let val_max = eval_unary_op(op, vmax);
match (val_min, val_max) {
(Some(min), Some(max)) => {
if compare_const_values(&min, &max) == Ordering::Greater {
(max, min)
} else {
(min, max)
}
}
_ => dtype_bounds(dtype),
}
}
}
UnaryOp::Sin | UnaryOp::Cos => {
(ConstValue::Float(-1.0), ConstValue::Float(1.0))
}
UnaryOp::Tan => {
dtype_bounds(dtype)
}
UnaryOp::Erf => {
(ConstValue::Float(-1.0), ConstValue::Float(1.0))
}
UnaryOp::Sign => {
match vmin {
ConstValue::Int(_) => (ConstValue::Int(-1), ConstValue::Int(1)),
ConstValue::Float(_) => (ConstValue::Float(-1.0), ConstValue::Float(1.0)),
ConstValue::UInt(_) => (ConstValue::UInt(0), ConstValue::UInt(1)),
_ => dtype_bounds(dtype),
}
}
UnaryOp::Square => {
let crosses_zero = match (vmin, vmax) {
(ConstValue::Int(min), ConstValue::Int(max)) => min <= 0 && max >= 0,
(ConstValue::Float(min), ConstValue::Float(max)) => min <= 0.0 && max >= 0.0,
_ => false,
};
if crosses_zero {
let zero = match vmin {
ConstValue::Int(_) => ConstValue::Int(0),
ConstValue::UInt(_) => ConstValue::UInt(0),
ConstValue::Float(_) => ConstValue::Float(0.0),
_ => dtype_bounds(dtype).0,
};
let sq_min = eval_unary_op(UnaryOp::Square, vmin);
let sq_max = eval_unary_op(UnaryOp::Square, vmax);
let max_val = match (sq_min, sq_max) {
(Some(a), Some(b)) => {
if compare_const_values(&a, &b) == Ordering::Greater {
a
} else {
b
}
}
_ => dtype_bounds(dtype).1,
};
(zero, max_val)
} else {
let val_min = eval_unary_op(op, vmin);
let val_max = eval_unary_op(op, vmax);
match (val_min, val_max) {
(Some(min), Some(max)) => {
if compare_const_values(&min, &max) == Ordering::Greater {
(max, min)
} else {
(min, max)
}
}
_ => dtype_bounds(dtype),
}
}
}
UnaryOp::Not => {
let new_min = eval_unary_op(UnaryOp::Not, vmax).unwrap_or_else(|| dtype_bounds(dtype).0);
let new_max = eval_unary_op(UnaryOp::Not, vmin).unwrap_or_else(|| dtype_bounds(dtype).1);
(new_min, new_max)
}
UnaryOp::Sqrt
| UnaryOp::Rsqrt
| UnaryOp::Exp
| UnaryOp::Exp2
| UnaryOp::Log
| UnaryOp::Log2
| UnaryOp::Reciprocal
| UnaryOp::Trunc
| UnaryOp::Floor
| UnaryOp::Ceil
| UnaryOp::Round => {
let val_min = eval_unary_op(op, vmin);
let val_max = eval_unary_op(op, vmax);
match (val_min, val_max) {
(Some(min), Some(max)) => {
if compare_const_values(&min, &max) == Ordering::Greater { (max, min) } else { (min, max) }
}
_ => dtype_bounds(dtype),
}
}
}
}
/// Conservative output range of `a op b` for `a ∈ [a_min, a_max]` and
/// `b ∈ [b_min, b_max]`.
///
/// When both operands are constants and `op` is not a comparison, the op is
/// evaluated exactly. Otherwise each op uses its own rule. Whenever a rule
/// cannot produce a bound, the full range of `dtype` is returned.
fn compute_binary_range(
    op: BinaryOp,
    a_min: ConstValue,
    a_max: ConstValue,
    b_min: ConstValue,
    b_max: ConstValue,
    dtype: &DType,
) -> (ConstValue, ConstValue) {
    use crate::uop::eval::eval_binary_op;
    // Comparisons are excluded here so they always go through ComparisonAnalyzer.
    if a_min == a_max
        && b_min == b_max
        && !matches!(op, BinaryOp::Lt | BinaryOp::Le | BinaryOp::Eq | BinaryOp::Ne | BinaryOp::Gt | BinaryOp::Ge)
    {
        if let Some(val) = eval_binary_op(op, a_min, b_min) {
            return (val, val);
        }
        return dtype_bounds(dtype);
    }
    match op {
        BinaryOp::Add => match (a_min, a_max, b_min, b_max) {
            (ConstValue::Int(amin), ConstValue::Int(amax), ConstValue::Int(bmin), ConstValue::Int(bmax)) => {
                match (amin.checked_add(bmin), amax.checked_add(bmax)) {
                    (Some(min), Some(max)) => (ConstValue::Int(min), ConstValue::Int(max)),
                    _ => dtype_bounds(dtype),
                }
            }
            (ConstValue::UInt(amin), ConstValue::UInt(amax), ConstValue::UInt(bmin), ConstValue::UInt(bmax)) => {
                match (amin.checked_add(bmin), amax.checked_add(bmax)) {
                    (Some(min), Some(max)) => (ConstValue::UInt(min), ConstValue::UInt(max)),
                    _ => dtype_bounds(dtype),
                }
            }
            _ => {
                let min = eval_binary_op(BinaryOp::Add, a_min, b_min).unwrap_or_else(|| dtype_bounds(dtype).0);
                let max = eval_binary_op(BinaryOp::Add, a_max, b_max).unwrap_or_else(|| dtype_bounds(dtype).1);
                (min, max)
            }
        },
        BinaryOp::Sub => match (a_min, a_max, b_min, b_max) {
            (ConstValue::Int(amin), ConstValue::Int(amax), ConstValue::Int(bmin), ConstValue::Int(bmax)) => {
                match (amin.checked_sub(bmax), amax.checked_sub(bmin)) {
                    (Some(min), Some(max)) => (ConstValue::Int(min), ConstValue::Int(max)),
                    _ => dtype_bounds(dtype),
                }
            }
            (ConstValue::UInt(amin), ConstValue::UInt(amax), ConstValue::UInt(bmin), ConstValue::UInt(bmax)) => {
                match (amin.checked_sub(bmax), amax.checked_sub(bmin)) {
                    (Some(min), Some(max)) => (ConstValue::UInt(min), ConstValue::UInt(max)),
                    _ => dtype_bounds(dtype),
                }
            }
            _ => {
                let min = eval_binary_op(BinaryOp::Sub, a_min, b_max).unwrap_or_else(|| dtype_bounds(dtype).0);
                let max = eval_binary_op(BinaryOp::Sub, a_max, b_min).unwrap_or_else(|| dtype_bounds(dtype).1);
                (min, max)
            }
        },
        BinaryOp::Max => {
            let min = eval_binary_op(BinaryOp::Max, a_min, b_min).unwrap_or_else(|| dtype_bounds(dtype).0);
            let max = eval_binary_op(BinaryOp::Max, a_max, b_max).unwrap_or_else(|| dtype_bounds(dtype).1);
            (min, max)
        }
        BinaryOp::Mul => eval_four_corners(op, a_min, a_max, b_min, b_max, dtype),
        // x^y is monotone in each operand only for a non-negative base x. With a
        // base range of [-2, 2], x^2 is 4 at every corner but 0 at x = 0, so
        // corner evaluation would be unsound there.
        BinaryOp::Pow
            if matches!(a_min, ConstValue::UInt(_))
                || matches!(a_min, ConstValue::Int(v) if v >= 0)
                || matches!(a_min, ConstValue::Float(f) if f >= 0.0) =>
        {
            eval_four_corners(op, a_min, a_max, b_min, b_max, dtype)
        }
        BinaryOp::Pow => dtype_bounds(dtype),
        BinaryOp::Idiv | BinaryOp::Fdiv => {
            if contains_zero(b_min, b_max) {
                dtype_bounds(dtype)
            } else {
                eval_four_corners(op, a_min, a_max, b_min, b_max, dtype)
            }
        }
        BinaryOp::Mod => match (a_min, a_max, b_min, b_max) {
            (ConstValue::Int(a_lo), ConstValue::Int(a_hi), ConstValue::Int(b_lo), ConstValue::Int(b_hi))
                if a_lo >= 0 && b_lo > 0 =>
            {
                (ConstValue::Int(0), ConstValue::Int(a_hi.min(b_hi - 1)))
            }
            (ConstValue::Int(_a_lo), ConstValue::Int(a_hi), ConstValue::Int(b_lo), ConstValue::Int(b_hi))
                if a_hi <= 0 && b_lo > 0 =>
            {
                (ConstValue::Int(-(b_hi - 1)), ConstValue::Int(0))
            }
            (ConstValue::Int(_), ConstValue::Int(_), ConstValue::Int(b_lo), ConstValue::Int(b_hi)) if b_lo > 0 => {
                (ConstValue::Int(-(b_hi - 1)), ConstValue::Int(b_hi - 1))
            }
            (ConstValue::UInt(_), ConstValue::UInt(a_hi), ConstValue::UInt(b_lo), ConstValue::UInt(b_hi))
                if b_lo > 0 =>
            {
                (ConstValue::UInt(0), ConstValue::UInt(a_hi.min(b_hi - 1)))
            }
            _ => dtype_bounds(dtype),
        },
        BinaryOp::Lt | BinaryOp::Le | BinaryOp::Eq | BinaryOp::Ne | BinaryOp::Gt | BinaryOp::Ge => {
            use crate::uop::comparison_analysis::ComparisonAnalyzer;
            ComparisonAnalyzer::get_comparison_range(op, a_min, a_max, b_min, b_max)
        }
        BinaryOp::And | BinaryOp::Or | BinaryOp::Xor => compute_bitwise_range(op, a_min, a_max, b_min, b_max, dtype),
        BinaryOp::Shl | BinaryOp::Shr => compute_shift_range(op, a_min, a_max, b_min, b_max, dtype),
        BinaryOp::Threefry => dtype_bounds(dtype),
    }
}
/// Range of a bitwise `And` / `Or` / `Xor`.
///
/// Boolean operands are handled by evaluating all four corners. For integers,
/// only `a & b` with a constant non-negative `b` is bounded, to `[0, b]`.
/// Every other case returns the full range of `dtype`.
fn compute_bitwise_range(
    op: BinaryOp,
    a_min: ConstValue,
    a_max: ConstValue,
    b_min: ConstValue,
    b_max: ConstValue,
    dtype: &DType,
) -> (ConstValue, ConstValue) {
    if dtype == &DType::Bool {
        return eval_four_corners(op, a_min, a_max, b_min, b_max, dtype);
    }
    match (op, b_min, b_max) {
        (BinaryOp::And, ConstValue::Int(lo), ConstValue::Int(hi)) if lo == hi && lo >= 0 => {
            (ConstValue::Int(0), ConstValue::Int(hi))
        }
        _ => dtype_bounds(dtype),
    }
}
/// Range of `a << b` / `a >> b` for fixed-width integer dtypes.
///
/// Shifts are monotone in each operand, so the four corners bound the result
/// as long as every possible shift amount lies in `[0, bit_width)`. Otherwise,
/// and for non-integer dtypes, the full range of `dtype` is returned.
fn compute_shift_range(
    op: BinaryOp,
    a_min: ConstValue,
    a_max: ConstValue,
    b_min: ConstValue,
    b_max: ConstValue,
    dtype: &DType,
) -> (ConstValue, ConstValue) {
    let bit_width: u32 = if dtype == &DType::Int8 || dtype == &DType::UInt8 {
        8
    } else if dtype == &DType::Int16 || dtype == &DType::UInt16 {
        16
    } else if dtype == &DType::Int32 || dtype == &DType::UInt32 {
        32
    } else if dtype == &DType::Int64 || dtype == &DType::UInt64 {
        64
    } else {
        return dtype_bounds(dtype);
    };
    let shift_in_range = match (b_min, b_max) {
        (ConstValue::Int(lo), ConstValue::Int(hi)) => lo >= 0 && hi < i64::from(bit_width),
        // Unsigned shift amounts are never negative, so any lower bound is
        // valid and only the upper bound needs checking.
        (ConstValue::UInt(_), ConstValue::UInt(hi)) => hi < u64::from(bit_width),
        _ => false,
    };
    if shift_in_range {
        eval_four_corners(op, a_min, a_max, b_min, b_max, dtype)
    } else {
        dtype_bounds(dtype)
    }
}
/// Range of a ternary op.
///
/// - `Where`: a condition fixed to `true` or `false` selects one branch's
///   range. Otherwise the result is the union of both branch ranges.
/// - `MulAcc` (`a * b + c`): the operands are passed in the
///   `cond` / `true` / `false` slots, and the result is bounded by evaluating
///   all eight corners. Falls back to `dtype`'s full range if no corner
///   evaluates.
#[allow(clippy::too_many_arguments)]
fn compute_ternary_range(
    op: TernaryOp,
    cond_min: ConstValue,
    cond_max: ConstValue,
    true_min: ConstValue,
    true_max: ConstValue,
    false_min: ConstValue,
    false_max: ConstValue,
    dtype: &DType,
) -> (ConstValue, ConstValue) {
    match op {
        TernaryOp::Where => match (cond_min, cond_max) {
            (ConstValue::Bool(true), ConstValue::Bool(true)) => (true_min, true_max),
            (ConstValue::Bool(false), ConstValue::Bool(false)) => (false_min, false_max),
            _ => range_union(&[true_min, true_max, false_min, false_max]),
        },
        TernaryOp::MulAcc => {
            use crate::uop::eval::eval_ternary_op;
            let mut bounds: Option<(ConstValue, ConstValue)> = None;
            for a in [cond_min, cond_max] {
                for b in [true_min, true_max] {
                    for c in [false_min, false_max] {
                        if let Some(v) = eval_ternary_op(TernaryOp::MulAcc, a, b, c) {
                            bounds = Some(match bounds {
                                None => (v, v),
                                Some((lo, hi)) => (min_value(lo, v), max_value(hi, v)),
                            });
                        }
                    }
                }
            }
            bounds.unwrap_or_else(|| dtype_bounds(dtype))
        }
    }
}
/// Bounds `a op b` by evaluating `op` at the four corners of the operand box.
///
/// Corners that fail to evaluate are skipped. If none evaluates, the full
/// range of `dtype` is returned.
fn eval_four_corners(
    op: BinaryOp,
    a_min: ConstValue,
    a_max: ConstValue,
    b_min: ConstValue,
    b_max: ConstValue,
    dtype: &DType,
) -> (ConstValue, ConstValue) {
    use crate::uop::eval::eval_binary_op;
    [(a_min, b_min), (a_min, b_max), (a_max, b_min), (a_max, b_max)]
        .into_iter()
        .filter_map(|(a, b)| eval_binary_op(op, a, b))
        .fold(None::<(ConstValue, ConstValue)>, |acc, v| match acc {
            None => Some((v, v)),
            Some((lo, hi)) => Some((min_value(lo, v), max_value(hi, v))),
        })
        .unwrap_or_else(|| dtype_bounds(dtype))
}
/// Full `(min, max)` range of `dtype`'s base scalar type.
pub fn dtype_bounds(dtype: &DType) -> (ConstValue, ConstValue) {
    let scalar = dtype.base();
    let lo = ConstValue::min(scalar);
    let hi = ConstValue::max(scalar);
    (lo, hi)
}
/// The smaller of `a` and `b` according to `compare_const_values`.
///
/// Returns `b` on ties.
fn min_value(a: ConstValue, b: ConstValue) -> ConstValue {
    match compare_const_values(&a, &b) {
        Ordering::Less => a,
        Ordering::Equal | Ordering::Greater => b,
    }
}
/// The larger of `a` and `b` according to `compare_const_values`.
///
/// Returns `b` on ties.
fn max_value(a: ConstValue, b: ConstValue) -> ConstValue {
    match compare_const_values(&a, &b) {
        Ordering::Greater => a,
        Ordering::Equal | Ordering::Less => b,
    }
}
fn range_union(values: &[ConstValue]) -> (ConstValue, ConstValue) {
let min = values.iter().copied().reduce(min_value).unwrap();
let max = values.iter().copied().reduce(max_value).unwrap();
(min, max)
}
/// Ordering over `ConstValue`s used by the range analysis.
///
/// - Values of the same kind compare naturally. Among floats, NaN sorts above
///   every number, and two NaNs compare `Equal`.
/// - `Int` vs `UInt` compares by numeric value, so a negative `Int` is below
///   every `UInt`. Before, this pair compared `Equal`, which made e.g.
///   `min_value(Int(0), UInt(5))` return `UInt(5)`.
/// - Any other mixed-kind pair compares `Equal`.
fn compare_const_values(a: &ConstValue, b: &ConstValue) -> Ordering {
    match (a, b) {
        (ConstValue::Int(x), ConstValue::Int(y)) => x.cmp(y),
        (ConstValue::UInt(x), ConstValue::UInt(y)) => x.cmp(y),
        // A negative i64 cannot convert to u64 and is below every u64.
        (ConstValue::Int(x), ConstValue::UInt(y)) => u64::try_from(*x).map_or(Ordering::Less, |lhs| lhs.cmp(y)),
        (ConstValue::UInt(x), ConstValue::Int(y)) => u64::try_from(*y).map_or(Ordering::Greater, |rhs| x.cmp(&rhs)),
        (ConstValue::Float(x), ConstValue::Float(y)) => match (x.is_nan(), y.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (false, false) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
        },
        (ConstValue::Bool(x), ConstValue::Bool(y)) => x.cmp(y),
        _ => Ordering::Equal,
    }
}
/// Whether the interval `[min, max]` may contain zero.
///
/// Used to guard division, where a wrong `false` is unsound. The answer is
/// therefore conservative: NaN float bounds, mixed-kind bounds and other
/// variants all report that zero may be present.
fn contains_zero(min: ConstValue, max: ConstValue) -> bool {
    match (min, max) {
        (ConstValue::Int(lo), ConstValue::Int(hi)) => lo <= 0 && hi >= 0,
        (ConstValue::UInt(lo), _) => lo == 0,
        (ConstValue::Float(lo), ConstValue::Float(hi)) => lo.is_nan() || hi.is_nan() || (lo <= 0.0 && hi >= 0.0),
        _ => true,
    }
}
/// Clamps `v` into `[min, max]` using `compare_const_values`.
///
/// The lower bound is checked first.
fn clamp_value(v: ConstValue, min: ConstValue, max: ConstValue) -> ConstValue {
    match (compare_const_values(&v, &min), compare_const_values(&v, &max)) {
        (Ordering::Less, _) => min,
        (_, Ordering::Greater) => max,
        _ => v,
    }
}