use crate::UOp;
use crate::types::{BinaryOp, ConstValue};
use std::cmp::Ordering;
use std::sync::Arc;
pub struct ComparisonAnalyzer;
impl ComparisonAnalyzer {
pub fn analyze(op: BinaryOp, x: &Arc<UOp>, y: &Arc<UOp>) -> Option<bool> {
use crate::uop::cached_property::CachedProperty;
use crate::uop::properties::VminVmaxProperty;
if Arc::ptr_eq(x, y) && !x.dtype().is_float() {
return match op {
BinaryOp::Lt => Some(false), BinaryOp::Eq => Some(true), BinaryOp::Ne => Some(false), _ => None,
};
}
if !Self::can_eliminate_comparison(x, y) {
return None;
}
let (x_min, x_max) = VminVmaxProperty::get(x);
let (y_min, y_max) = VminVmaxProperty::get(y);
Self::analyze_with_ranges(op, *x_min, *x_max, *y_min, *y_max)
}
pub fn analyze_with_ranges(
op: BinaryOp,
x_min: ConstValue,
x_max: ConstValue,
y_min: ConstValue,
y_max: ConstValue,
) -> Option<bool> {
use BinaryOp::*;
match op {
Lt => Self::analyze_lt(x_min, x_max, y_min, y_max),
Le => Self::analyze_le(x_min, x_max, y_min, y_max),
Eq => Self::analyze_eq(x_min, x_max, y_min, y_max),
Ne => Self::analyze_ne(x_min, x_max, y_min, y_max),
Gt => Self::analyze_gt(x_min, x_max, y_min, y_max),
Ge => Self::analyze_ge(x_min, x_max, y_min, y_max),
_ => None,
}
}
pub fn analyze_extended(
op_name: &str,
x_min: ConstValue,
x_max: ConstValue,
y_min: ConstValue,
y_max: ConstValue,
) -> Option<bool> {
match op_name {
"le" => Self::analyze_le(x_min, x_max, y_min, y_max),
"gt" => Self::analyze_gt(x_min, x_max, y_min, y_max),
"ge" => Self::analyze_ge(x_min, x_max, y_min, y_max),
_ => None,
}
}
pub fn get_comparison_range(
op: BinaryOp,
x_min: ConstValue,
x_max: ConstValue,
y_min: ConstValue,
y_max: ConstValue,
) -> (ConstValue, ConstValue) {
let has_nan = [&x_min, &x_max, &y_min, &y_max].iter().any(|v| matches!(v, ConstValue::Float(f) if f.is_nan()));
if has_nan {
return (ConstValue::Bool(false), ConstValue::Bool(true));
}
match Self::analyze_with_ranges(op, x_min, x_max, y_min, y_max) {
Some(true) => (ConstValue::Bool(true), ConstValue::Bool(true)),
Some(false) => (ConstValue::Bool(false), ConstValue::Bool(false)),
None => (ConstValue::Bool(false), ConstValue::Bool(true)),
}
}
fn can_eliminate_comparison(x: &Arc<UOp>, y: &Arc<UOp>) -> bool {
let dtype = x.dtype();
if !dtype.is_float() {
return true;
}
use crate::uop::cached_property::CachedProperty;
use crate::uop::properties::VminVmaxProperty;
let check_nan = |uop: &Arc<UOp>| {
let (min, max) = VminVmaxProperty::get(uop);
matches!(min, ConstValue::Float(f) if f.is_nan()) || matches!(max, ConstValue::Float(f) if f.is_nan())
};
!check_nan(x) && !check_nan(y)
}
fn analyze_lt(x_min: ConstValue, x_max: ConstValue, y_min: ConstValue, y_max: ConstValue) -> Option<bool> {
use Ordering::*;
match (Self::compare_values(&x_max, &y_min), Self::compare_values(&x_min, &y_max)) {
(Less, _) => Some(true), (_, ord) if ord != Less => Some(false), _ => None,
}
}
fn analyze_le(x_min: ConstValue, x_max: ConstValue, y_min: ConstValue, y_max: ConstValue) -> Option<bool> {
use Ordering::*;
match (Self::compare_values(&x_max, &y_min), Self::compare_values(&x_min, &y_max)) {
(ord, _) if ord != Greater => Some(true), (_, Greater) => Some(false), _ => None,
}
}
fn analyze_gt(x_min: ConstValue, x_max: ConstValue, y_min: ConstValue, y_max: ConstValue) -> Option<bool> {
Self::analyze_lt(y_min, y_max, x_min, x_max)
}
fn analyze_ge(x_min: ConstValue, x_max: ConstValue, y_min: ConstValue, y_max: ConstValue) -> Option<bool> {
Self::analyze_le(y_min, y_max, x_min, x_max)
}
fn analyze_eq(x_min: ConstValue, x_max: ConstValue, y_min: ConstValue, y_max: ConstValue) -> Option<bool> {
let no_overlap = Self::ranges_disjoint(x_min, x_max, y_min, y_max);
if no_overlap {
Some(false) } else if x_min == x_max && y_min == y_max && x_min == y_min {
Some(true) } else {
None }
}
fn analyze_ne(x_min: ConstValue, x_max: ConstValue, y_min: ConstValue, y_max: ConstValue) -> Option<bool> {
Self::analyze_eq(x_min, x_max, y_min, y_max).map(|v| !v)
}
fn ranges_disjoint(x_min: ConstValue, x_max: ConstValue, y_min: ConstValue, y_max: ConstValue) -> bool {
use Ordering::*;
matches!((Self::compare_values(&x_max, &y_min), Self::compare_values(&x_min, &y_max)), (Less, _) | (_, Greater))
}
fn compare_values(a: &ConstValue, b: &ConstValue) -> Ordering {
match (a, b) {
(ConstValue::Int(x), ConstValue::Int(y)) => x.cmp(y),
(ConstValue::UInt(x), ConstValue::UInt(y)) => x.cmp(y),
(ConstValue::Float(x), ConstValue::Float(y)) => {
debug_assert!(!x.is_nan() && !y.is_nan(), "NaN should have been filtered by can_eliminate_comparison");
if x.is_nan() || y.is_nan() {
Ordering::Equal
} else {
x.partial_cmp(y).unwrap_or(Ordering::Equal)
}
}
(ConstValue::Bool(x), ConstValue::Bool(y)) => x.cmp(y),
_ => Ordering::Equal, }
}
}
pub fn analyze_all_comparisons(x: &Arc<UOp>, y: &Arc<UOp>) -> (Option<bool>, Option<bool>, Option<bool>) {
let lt = ComparisonAnalyzer::analyze(BinaryOp::Lt, x, y);
let eq = ComparisonAnalyzer::analyze(BinaryOp::Eq, x, y);
let ne = ComparisonAnalyzer::analyze(BinaryOp::Ne, x, y);
(lt, eq, ne)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the range analysis for `op` on `(min, max)` bound pairs for `x` and `y`.
    fn analyze_bounds(
        op: BinaryOp,
        (x_min, x_max): (ConstValue, ConstValue),
        (y_min, y_max): (ConstValue, ConstValue),
    ) -> Option<bool> {
        ComparisonAnalyzer::analyze_with_ranges(op, x_min, x_max, y_min, y_max)
    }

    #[test]
    fn test_lt_analysis() {
        // Every value of [100, 200] is above 50, so `x < y` never holds.
        let never = analyze_bounds(
            BinaryOp::Lt,
            (ConstValue::Int(100), ConstValue::Int(200)),
            (ConstValue::Int(50), ConstValue::Int(50)),
        );
        assert_eq!(never, Some(false));

        // [0, 10] lies entirely below [20, 30], so `x < y` always holds.
        let always = analyze_bounds(
            BinaryOp::Lt,
            (ConstValue::Int(0), ConstValue::Int(10)),
            (ConstValue::Int(20), ConstValue::Int(30)),
        );
        assert_eq!(always, Some(true));
    }

    #[test]
    fn test_eq_analysis() {
        // Disjoint ranges can never be equal.
        let disjoint = analyze_bounds(
            BinaryOp::Eq,
            (ConstValue::Int(0), ConstValue::Int(10)),
            (ConstValue::Int(20), ConstValue::Int(30)),
        );
        assert_eq!(disjoint, Some(false));

        // Two identical single-value ranges are always equal.
        let same_constant = analyze_bounds(
            BinaryOp::Eq,
            (ConstValue::Int(5), ConstValue::Int(5)),
            (ConstValue::Int(5), ConstValue::Int(5)),
        );
        assert_eq!(same_constant, Some(true));
    }
}