#![allow(clippy::bool_comparison)]
#![allow(clippy::unnecessary_cast)]
mod comparison;
mod ite;
pub use comparison::{CompEq, CompGT, CompGTE, CompLT, CompLTE, CompNE};
pub use comparison::{comp_eq, comp_gt, comp_gte, comp_lt, comp_lte, comp_ne};
pub use ite::IfThenElse;
use ndarray::*;
use crate::broadcast::multi_broadcast;
use crate::internal::*;
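// Logical `and` / `or` over bool and all integer types: any non-zero value is
// treated as true. The `neutral_element` / `absorbing_element` annotations are
// used by the generic binary declutter rules to fold `x AND true -> x`,
// `x AND false -> false`, `x OR false -> x` and `x OR true -> true`.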
bin_to_super_type!(and, And,
neutral_element: 1,
absorbing_element: 0,
[bool, u8, u16, u32, u64, i8, i16, i32, i64] => |c, &a, &b| *c = (a as i64 != 0 && b as i64 != 0) as _);
bin_to_super_type!(or, Or,
neutral_element: 0,
absorbing_element: 1,
[bool, u8, u16, u32, u64, i8, i16, i32, i64] => |c, &a, &b| *c = (a as i64 != 0 || b as i64 != 0) as _);
bin_to_super_type!(xor, Xor, declutter: declutter_xor, neutral_element: 0, [bool] => |c, &a, &b| *c = a ^ b);
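// Declutter rule for `xor`: when one input is a uniform all-true constant,
// `x XOR true == NOT x`, so the node is replaced by an element-wise `Not` on
// the other input.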
fn declutter_xor(
_op: &Xor,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
if tensor0(1i64).close_enough(&uniform.uni, false).is_ok() {
return Ok(Some(TypedModelPatch::replace_single_op(
model,
node,
&[uniform.var],
crate::ops::element_wise::ElementWiseOp(Box::new(Not {}), None),
)?));
}
}
Ok(None)
}
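// In-place logical negation, defined for bool only.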
element_wise!(not, Not, [bool] => |_, vs| {
vs.iter_mut().for_each(|a| *a = !*a);
Ok(())
});
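/// Element-wise select: for each coordinate, picks the value from the "then"
/// input where the boolean condition is true, and from the "else" input
/// otherwise. All three inputs are broadcast together (akin to ONNX `Where`).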
#[derive(Debug, Clone, new, Default, Hash, PartialEq, Eq)]
pub struct Iff;
impl Iff {
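    /// # Safety
    ///
    /// `out` must already be allocated with datum type `T` and the broadcast
    /// output shape, and `t` and `f` must both hold values of type `T`: the
    /// array views are obtained with unchecked accessors, so a datum type
    /// mismatch is undefined behaviour.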
pub unsafe fn eval_t<T: Datum>(
cond: &ArrayViewD<bool>,
out: &mut Tensor,
t: &Tensor,
f: &Tensor,
) {
unsafe {
Zip::from(out.to_array_view_mut_unchecked::<T>())
.and_broadcast(cond)
.and_broadcast(t.to_array_view_unchecked::<T>())
.and_broadcast(f.to_array_view_unchecked::<T>())
.for_each(|r, c, t, f| *r = if *c { t.clone() } else { f.clone() })
}
}
}
impl Op for Iff {
fn name(&self) -> StaticName {
"Iff".into()
}
op_as_typed_op!();
}
impl EvalOp for Iff {
fn is_stateless(&self) -> bool {
true
}
fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
let (cond, t, f) = args_3!(inputs);
anyhow::ensure!(t.datum_type() == f.datum_type());
let shape: TVec<usize> = multi_broadcast(&[cond.shape(), t.shape(), f.shape()])?;
unsafe {
let mut result = Tensor::uninitialized_dt(t.datum_type(), &shape)?;
let cond = cond.to_plain_array_view::<bool>()?;
dispatch_datum_by_size!(Self::eval_t(t.datum_type())(&cond, &mut result, &t, &f));
Ok(tvec!(result.into_tvalue()))
}
}
}
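/// Maps a coordinate symbol of the form `🎯<axis>` back to its axis index,
/// e.g. a symbol that prints as `🎯1` yields `Some(1)`; any other symbol yields
/// `None`.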
pub fn sym_to_coord_axis(sym: &Symbol) -> Option<usize> {
format!("{sym}").strip_prefix("🎯")?.parse::<usize>().ok()
}
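/// For every coordinate symbol `🎯k` appearing in `expr` (with `k` a valid axis
/// of `shape`), produces the assertions `🎯k >= 0` and `🎯k <= shape[k] - 1`, so
/// the expression can be simplified knowing the coordinate stays inside the
/// tensor.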
pub(crate) fn coord_bound_assertions(expr: &TDim, shape: &ShapeFact) -> Vec<Assertion> {
expr.symbols()
.into_iter()
.filter_map(|s| sym_to_coord_axis(&s).filter(|k| *k < shape.rank()).map(|k| (k, s)))
.flat_map(|(k, sym)| {
[
Assertion::GTE(TDim::Sym(sym.clone()), TDim::Val(0)),
Assertion::LTE(TDim::Sym(sym), shape[k].clone() - TDim::Val(1)),
]
})
.collect()
}
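// A mask expression is provably all-false (resp. all-true) if, under the
// coordinate bound assertions above, it simplifies to the constant 0 (resp. 1)
// for every in-range coordinate.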
pub(crate) fn is_provably_all_false(expr: &TDim, shape: &ShapeFact) -> bool {
let extra = coord_bound_assertions(expr, shape);
expr.clone().simplify_with_extra_assertions(&extra) == TDim::Val(0)
}
pub(crate) fn is_provably_all_true(expr: &TDim, shape: &ShapeFact) -> bool {
let extra = coord_bound_assertions(expr, shape);
expr.clone().simplify_with_extra_assertions(&extra) == TDim::Val(1)
}
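/// A contiguous range of coordinates along `axis` where a boolean mask
/// expression is true: `None` for `start` means "from the beginning of the
/// axis", `None` for `end` means "to the end of the axis".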
#[derive(Debug, Clone)]
pub(crate) struct TrueRange {
pub axis: usize,
    pub start: Option<TDim>,
    pub end: Option<TDim>,
}
impl TrueRange {
pub fn is_full(&self) -> bool {
self.start.is_none() && self.end.is_none()
}
pub fn is_empty(&self) -> bool {
match (&self.start, &self.end) {
(None, Some(e)) => *e == TDim::Val(0),
(Some(s), Some(e)) => s == e,
_ => false,
}
}
}
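/// Tries to recognize a mask expression as a single contiguous true range
/// along one axis: provably all-false or all-true masks, a `🎯k >= split`
/// pattern (true tail starting at `split`), or its complement `🎯k < split`
/// (true head ending at `split`). Returns `None` for anything else.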
pub(crate) fn classify_true_range(expr: &TDim, shape: &ShapeFact) -> Option<TrueRange> {
fn try_ge(ge: &TDim, shape: &ShapeFact) -> Option<(usize, TDim)> {
if let TDim::Ge(lhs, rhs) = ge {
if let TDim::Sym(sym) = &**lhs {
let k = sym_to_coord_axis(sym)?;
if k < shape.rank() && !rhs.symbols().contains(sym) {
return Some((k, *rhs.clone()));
}
}
}
None
}
let simplified = expr.clone().simplify();
if simplified == TDim::Val(0) || is_provably_all_false(&simplified, shape) {
return Some(TrueRange { axis: 0, start: None, end: Some(TDim::Val(0)) });
}
if simplified == TDim::Val(1) || is_provably_all_true(&simplified, shape) {
return Some(TrueRange { axis: 0, start: None, end: None });
}
if let Some((axis, split)) = try_ge(&simplified, shape) {
return Some(TrueRange { axis, start: Some(split), end: None });
}
let flipped = (TDim::Val(1) - simplified).simplify();
if let Some((axis, split)) = try_ge(&flipped, shape) {
return Some(TrueRange { axis, start: None, end: Some(split) });
}
None
}
impl TypedOp for Iff {
as_op!();
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
ensure!(inputs.len() == 3, "Iff expects 3 inputs.");
ensure!(inputs[1].datum_type == inputs[2].datum_type);
ensure!(inputs[0].datum_type.is::<bool>());
ensure!(inputs[0].rank() == inputs[1].rank());
ensure!(inputs[0].rank() == inputs[2].rank());
let shape = multi_broadcast(&[
inputs[0].shape.to_tvec(),
inputs[1].shape.to_tvec(),
inputs[2].shape.to_tvec(),
        ])?;
let mut fact = inputs[1].datum_type.fact(shape);
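        // If the condition's symbolic per-element value simplifies to a
        // constant, the output inherits the symbolic value of the branch that
        // is selected everywhere; otherwise nothing can be propagated.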
fact.uniform_tdim = match inputs[0].uniform_tdim.as_ref().map(|d| d.clone().simplify()) {
            Some(TDim::Val(0)) => inputs[2].uniform_tdim.clone(),
            Some(TDim::Val(_)) => inputs[1].uniform_tdim.clone(),
            _ => None,
};
Ok(tvec!(fact))
}
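    /// Region-of-interest propagation: when the condition has a symbolic
    /// per-element expression, the "then" input only contributes where that
    /// expression is non-zero and the "else" input only where it is zero; the
    /// condition input itself is left unrestricted (`None`).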
fn input_roi(
&self,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TVec<Option<TDim>>>> {
let cond_fact = model.outlet_fact(node.inputs[0])?;
if let Some(cond_expr) = &cond_fact.uniform_tdim {
let cond = cond_expr.clone().simplify();
let not_cond = TDim::Eq(Box::new(cond.clone()), Box::new(TDim::Val(0))).simplify();
return Ok(Some(tvec![None, Some(cond), Some(not_cond)]));
}
crate::optim::propagate_roi::bubble_roi(model, node)
}
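    // Declutter: if the condition is a uniform compile-time constant, the Iff
    // always picks the same branch, so the node is shunted to that branch's
    // wire.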
fn declutter(
&self,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
let cond_fact = model.outlet_fact(node.inputs[0])?;
rule_if_some!(uniform = &cond_fact.uniform);
let Ok(cond_val) = uniform.cast_to_scalar::<bool>() else { return Ok(None) };
let branch = if cond_val { node.inputs[1] } else { node.inputs[2] };
let mut patch = TypedModelPatch::default();
let wire = patch.tap_model(model, branch)?;
patch.shunt_outside(model, node.id.into(), wire)?;
Ok(Some(patch))
}
fn axes_mapping(
&self,
inputs: &[&TypedFact],
outputs: &[&TypedFact],
) -> TractResult<AxesMapping> {
AxesMapping::natural(inputs, outputs)
}
}
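// Bitwise counterparts of the logical operators, defined over bool and all
// integer types. 0 is absorbing for AND and neutral for OR and XOR; AND
// declares no neutral element here, presumably because its all-ones neutral
// value depends on the concrete integer type.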
bin_to_super_type!(bitand, BitAnd,
absorbing_element: 0,
[bool, u8, u16, u32, u64, i8, i16, i32, i64] => |c, &a, &b| *c = a & b);
bin_to_super_type!(bitor, BitOr,
neutral_element: 0,
[bool, u8, u16, u32, u64, i8, i16, i32, i64] => |c, &a, &b| *c = a | b);
bin_to_super_type!(bitxor, BitXor,
declutter: declutter_bitxor,
neutral_element: 0,
[bool, u8, u16, u32, u64, i8, i16, i32, i64] => |c, &a, &b| *c = a ^ b);
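// Declutter rule for `bitxor`: XOR with an all-ones uniform constant (`true`
// for bool, `-1` i.e. all bits set for integers) is a bitwise NOT of the other
// input.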
fn declutter_bitxor(
_op: &BitXor,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
let var_dt = model.outlet_fact(uniform.var)?.datum_type;
let is_all_ones = if var_dt.is::<bool>() {
tensor0(1i64).close_enough(&uniform.uni, false).is_ok()
} else {
tensor0(-1i64).close_enough(&uniform.uni, false).is_ok()
};
if is_all_ones {
return Ok(Some(TypedModelPatch::replace_single_op(
model,
node,
&[uniform.var],
crate::ops::element_wise::ElementWiseOp(Box::new(BitNot {}), None),
)?));
}
}
Ok(None)
}
element_wise!(bitnot, BitNot, [bool, u8, u16, u32, u64, i8, i16, i32, i64] => |_, xs| {
xs.iter_mut().for_each(|x| *x = !*x);
Ok(())
});
#[cfg(test)]
mod tests {
use super::*;
use crate::ops::array::TypedConcat;
use crate::ops::binary::TypedBinOp;
use crate::ops::change_axes::AxisOp;
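    // `T == 0` is provably false under the `T >= 1` assertion, so the Iff's
    // condition is a constant false and the node should fold into its "else"
    // branch.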
#[test]
fn iff_fold_case1_eq_t_zero() -> TractResult<()> {
let mut model = TypedModel::default();
model.symbols.add_assertion("T >= 1")?;
let t_sym = model.symbols.sym("T");
let t_dim = TDim::Sym(t_sym.clone());
let t_wire = model.wire_node(
"T",
crate::ops::konst::Const::new(tensor0(t_dim.clone()).into_arc_tensor())?,
&[],
)?[0];
let zero_wire = model.wire_node(
"zero",
crate::ops::konst::Const::new(tensor0(TDim::Val(0)).into_arc_tensor())?,
&[],
)?[0];
let eq_wire = model.wire_node("eq", TypedBinOp(comp_eq(), None), &[t_wire, zero_wire])?[0];
let data_wire = model.add_source("data", TDim::datum_type().scalar_fact())?;
let iff_wire = model.wire_node("iff", Iff, &[eq_wire, zero_wire, data_wire])?[0];
model.select_output_outlets(&[iff_wire])?;
let model = model.into_decluttered()?;
let iff_count = model.nodes().iter().filter(|n| n.op_as::<Iff>().is_some()).count();
assert_eq!(iff_count, 0, "Expected Iff to be folded, but found {iff_count} Iff nodes");
Ok(())
}
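    // The mask `!(range(T) < T)` is provably all-false once the coordinate
    // bounds `0 <= 🎯k < T` are taken into account, so the Iff should fold into
    // its "else" branch.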
#[test]
fn iff_fold_case2_not_lt_x1_t() -> TractResult<()> {
use crate::ops::array::Range;
let mut model = TypedModel::default();
model.symbols.add_assertion("T >= 1")?;
let t_sym = model.symbols.sym("T");
let t_dim = TDim::Sym(t_sym.clone());
let start = model.wire_node(
"start",
crate::ops::konst::Const::new(tensor0(TDim::Val(0)).into_arc_tensor())?,
&[],
)?[0];
let step = model.wire_node(
"step",
crate::ops::konst::Const::new(tensor0(TDim::Val(1)).into_arc_tensor())?,
&[],
)?[0];
let end = model.add_source("T_dyn", TDim::datum_type().scalar_fact())?;
let range = model.wire_node("range", Range::new(t_dim.clone()), &[start, end, step])?[0];
let range_unsq = model.wire_node("range_unsq", AxisOp::Add(0), &[range])?[0];
let t_const = model.wire_node(
"T_const",
crate::ops::konst::Const::new(tensor0(t_dim.clone()).into_arc_tensor())?,
&[],
)?[0];
let t_unsq = model.wire_node("T_unsq", AxisOp::Add(0), &[t_const])?[0];
let t_unsq2 = model.wire_node("T_unsq2", AxisOp::Add(0), &[t_unsq])?[0];
let lt = model.wire_node("lt", TypedBinOp(comp_lt(), None), &[range_unsq, t_unsq2])?[0];
let bn = model.wire_node("bitnot", bitnot(), &[lt])?[0];
let data_shape = tvec![TDim::Val(1), t_dim.clone()];
let data = model.add_source("data", TDim::datum_type().fact(data_shape.clone()))?;
let zero_scalar = model.wire_node(
"zero_s",
crate::ops::konst::Const::new(tensor0(TDim::Val(0)).into_arc_tensor())?,
&[],
)?[0];
let zeros = model.wire_node(
"zeros",
crate::ops::array::MultiBroadcastTo {
shape: ShapeFact::from_dims(data_shape.iter().cloned()),
},
&[zero_scalar],
)?[0];
let iff = model.wire_node("iff", Iff, &[bn, zeros, data])?[0];
model.select_output_outlets(&[iff])?;
let model = model.into_decluttered()?;
let iff_count = model.nodes().iter().filter(|n| n.op_as::<Iff>().is_some()).count();
assert_eq!(iff_count, 0, "Expected Iff to be folded, but found {iff_count} Iff nodes");
Ok(())
}
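    // The condition `range >= T/160` is a contiguous true tail along the last
    // axis, so the declutter pass is expected to replace the Iff by a slice of
    // each branch concatenated along that axis.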
#[test]
fn iff_split_to_slice_concat() -> TractResult<()> {
use crate::ops::array::Range;
let mut model = TypedModel::default();
model.symbols.add_assertion("T >= 160")?;
let t_sym = model.symbols.sym("T");
let t_dim = TDim::Sym(t_sym.clone());
let split = t_dim.clone() / 160;
let out_len = TDim::Val(1) + split.clone();
let start = model.wire_node(
"start",
crate::ops::konst::Const::new(tensor0(TDim::Val(0)).into_arc_tensor())?,
&[],
)?[0];
let step = model.wire_node(
"step",
crate::ops::konst::Const::new(tensor0(TDim::Val(1)).into_arc_tensor())?,
&[],
)?[0];
let end_val = model.wire_node(
"end_val",
crate::ops::konst::Const::new(tensor0(out_len.clone()).into_arc_tensor())?,
&[],
)?[0];
let range =
model.wire_node("range", Range::new(out_len.clone()), &[start, end_val, step])?[0];
let r1 = model.wire_node("r1", AxisOp::Add(0), &[range])?[0];
let r2 = model.wire_node("r2", AxisOp::Add(0), &[r1])?[0];
let split_const = model.wire_node(
"split_const",
crate::ops::konst::Const::new(tensor0(split.clone()).into_arc_tensor())?,
&[],
)?[0];
let sc1 = model.wire_node("sc1", AxisOp::Add(0), &[split_const])?[0];
let sc2 = model.wire_node("sc2", AxisOp::Add(0), &[sc1])?[0];
let sc2 = model.wire_node("sc3", AxisOp::Add(0), &[sc2])?[0];
let cond = model.wire_node("cond", TypedBinOp(comp_gte(), None), &[r2, sc2])?[0];
let true_branch = model.add_source(
"true_b",
TDim::datum_type().fact(tvec![TDim::Val(1), TDim::Val(1), out_len.clone()]),
)?;
let false_branch = model.add_source(
"false_b",
TDim::datum_type().fact(tvec![TDim::Val(1), TDim::Val(1), out_len.clone()]),
)?;
let iff = model.wire_node("iff", Iff, &[cond, true_branch, false_branch])?[0];
model.select_output_outlets(&[iff])?;
let model = model.into_decluttered()?;
let iff_count = model.nodes().iter().filter(|n| n.op_as::<Iff>().is_some()).count();
assert_eq!(iff_count, 0, "Expected no Iff nodes after declutter, found {iff_count}");
let concat_count =
model.nodes().iter().filter(|n| n.op_as::<TypedConcat>().is_some()).count();
assert!(concat_count > 0, "Expected at least one Concat node after declutter");
Ok(())
}
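    // Sanity check that the symbolic per-element value (`uniform_tdim`) set on
    // the Range output survives AxisOp insertions and the `<` comparison, since
    // the Iff folding above relies on it.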
#[test]
fn verify_uniform_tdim_propagation() -> TractResult<()> {
use crate::ops::array::Range;
let mut model = TypedModel::default();
model.symbols.add_assertion("T >= 1")?;
let t_sym = model.symbols.sym("T");
let t_dim = TDim::Sym(t_sym.clone());
let start = model.wire_node(
"start",
crate::ops::konst::Const::new(tensor0(TDim::Val(0)).into_arc_tensor())?,
&[],
)?[0];
let step = model.wire_node(
"step",
crate::ops::konst::Const::new(tensor0(TDim::Val(1)).into_arc_tensor())?,
&[],
)?[0];
let end = model.add_source("T_dyn", TDim::datum_type().scalar_fact())?;
let range = model.wire_node("range", Range::new(t_dim.clone()), &[start, end, step])?[0];
let range_unsq = model.wire_node("range_unsq", AxisOp::Add(0), &[range])?[0];
let t_const = model.wire_node(
"T_const",
crate::ops::konst::Const::new(tensor0(t_dim.clone()).into_arc_tensor())?,
&[],
)?[0];
let t_unsq = model.wire_node("T_unsq", AxisOp::Add(0), &[t_const])?[0];
let t_unsq2 = model.wire_node("T_unsq2", AxisOp::Add(0), &[t_unsq])?[0];
let lt = model.wire_node("lt", TypedBinOp(comp_lt(), None), &[range_unsq, t_unsq2])?[0];
let range_fact = model.outlet_fact(range)?;
let range_unsq_fact = model.outlet_fact(range_unsq)?;
let t_unsq_fact = model.outlet_fact(t_unsq)?;
let lt_fact = model.outlet_fact(lt)?;
assert!(range_fact.uniform_tdim.is_some(), "range should have uniform_tdim");
assert!(range_unsq_fact.uniform_tdim.is_some(), "range_unsq should have uniform_tdim");
assert!(t_unsq_fact.uniform_tdim.is_some(), "t_unsq should have uniform_tdim");
assert!(lt_fact.uniform_tdim.is_some(), "lt should have uniform_tdim");
Ok(())
}
}