use crate::internal::*;
use crate::ops::array::{Slice, TypedConcat};
use crate::ops::binary::TypedBinOp;
use crate::ops::logic::{And, Iff, classify_true_range};
use crate::optim::OptimizerSession;
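
/// Graph-rewrite pass that folds boolean mask inputs whose truth value is
/// known to be uniform: all true, all false, or true over a single
/// contiguous range along one axis. The wrapped `usize` is a cursor over
/// node ids, letting `next` resume scanning where the previous call stopped.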
#[derive(Clone, Debug, Default)]
pub struct FoldUniformMask(usize);
impl super::TypedPass for FoldUniformMask {
    fn reset(&mut self) -> TractResult<()> {
        self.0 = 0;
        Ok(())
    }

    fn next(
        &mut self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
    ) -> TractResult<Option<TypedModelPatch>> {
        for node in &model.nodes[self.0..] {
            self.0 = node.id + 1;
            if let Some(patch) = try_fold_node(model, node)? {
                return Ok(Some(patch));
            }
        }
        Ok(None)
    }
}
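
/// Decides which inputs of `node` are candidate boolean masks: the condition
/// of an `Iff`, or either operand of a binary op that is either a boolean
/// `And` or has 1 as its neutral element (a multiplicative 0/1 mask).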
fn try_fold_node(model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
    let bool_indices: &[usize] = if node.op_is::<Iff>() {
        &[0]
    } else if let Some(bin_op) = node.op_as::<TypedBinOp>() {
        let out_bool = model.outlet_fact(node.id.into())?.datum_type == bool::datum_type();
        if (bin_op.0.is::<And>() && out_bool) || bin_op.0.neutral_element() == Some(1) {
            &[0, 1]
        } else {
            return Ok(None);
        }
    } else {
        return Ok(None);
    };
    for &bool_ix in bool_indices {
        if let Some(patch) = try_fold_uniform_bool_input(model, node, bool_ix)? {
            return Ok(Some(patch));
        }
    }
    Ok(None)
}
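
/// Tries to fold the boolean input at `bool_ix`. If its true-range covers
/// the whole tensor (or none of it), the mask is replaced by a broadcastable
/// scalar constant; a partial range triggers a split of the op into
/// per-region computations along the range axis.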
fn try_fold_uniform_bool_input(
    model: &TypedModel,
    node: &TypedNode,
    bool_ix: usize,
) -> TractResult<Option<TypedModelPatch>> {
    let bool_fact = model.outlet_fact(node.inputs[bool_ix])?;
    rule_if!(bool_fact.konst.is_none());
    rule_if_some!(expr = &bool_fact.uniform_tdim);
    rule_if_some!(range = classify_true_range(expr, &bool_fact.shape));
    let dt = bool_fact.datum_type;
    let rank = bool_fact.rank();
    if range.is_full() {
        return inject_scalar_const(model, node, bool_ix, dt, rank, true);
    }
    if range.is_empty() {
        return inject_scalar_const(model, node, bool_ix, dt, rank, false);
    }
    split_op_regions(model, node, bool_ix, dt, rank, &range)
}
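
/// Rebuilds `node` in a patch with the boolean input at `bool_ix` replaced
/// by a rank-matched uniform constant, then shunts the original node.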
fn inject_scalar_const(
    model: &TypedModel,
    node: &TypedNode,
    bool_ix: usize,
    bool_dt: DatumType,
    bool_rank: usize,
    is_true: bool,
) -> TractResult<Option<TypedModelPatch>> {
    let const_tensor = uniform_const(bool_dt, bool_rank, is_true)?;
    let mut patch = TypedModelPatch::default();
    let const_wire = patch.wire_node(
        format!("{}.bool_const", node.name),
        crate::ops::konst::Const::new(const_tensor.into_arc_tensor())?,
        &[],
    )?[0];
    let mut new_inputs = tvec![];
    for (ix, &outlet) in node.inputs.iter().enumerate() {
        new_inputs.push(if ix == bool_ix { const_wire } else { patch.tap_model(model, outlet)? });
    }
    let out = patch.wire_node(&node.name, node.op.clone(), &new_inputs)?[0];
    patch.shunt_outside(model, node.id.into(), out)?;
    Ok(Some(patch))
}
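
/// Splits `node` along `range.axis` into contiguous regions that are each
/// uniformly true or false, wires one copy of the op per region (slicing
/// full-size inputs, passing broadcast inputs through), and concatenates the
/// region outputs back along the axis.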
fn split_op_regions(
    model: &TypedModel,
    node: &TypedNode,
    bool_ix: usize,
    bool_dt: DatumType,
    bool_rank: usize,
    range: &crate::ops::logic::TrueRange,
) -> TractResult<Option<TypedModelPatch>> {
    let axis = range.axis;
    let out_dim = model.outlet_fact(node.id.into())?.shape[axis].clone();
    let regions: TVec<(TDim, TDim, bool)> = match (&range.start, &range.end) {
        (None, Some(e)) => {
            tvec![(TDim::Val(0), e.clone(), true), (e.clone(), out_dim.clone(), false)]
        }
        (Some(s), None) => {
            tvec![(TDim::Val(0), s.clone(), false), (s.clone(), out_dim.clone(), true)]
        }
        (Some(s), Some(e)) => tvec![
            (TDim::Val(0), s.clone(), false),
            (s.clone(), e.clone(), true),
            (e.clone(), out_dim.clone(), false),
        ],
        _ => return Ok(None),
    };
    let mut patch = TypedModelPatch::default();
    let mut region_outs = tvec![];
    for (r_start, r_end, is_true) in &regions {
        let mut region_inputs = tvec![OutletId::new(0, 0); node.inputs.len()];
        for (ix, &outlet) in node.inputs.iter().enumerate() {
            if ix == bool_ix {
                let c = uniform_const(bool_dt, bool_rank, *is_true)?;
                region_inputs[ix] = patch.wire_node(
                    format!("{}.bool_{r_start}", node.name),
                    crate::ops::konst::Const::new(c.into_arc_tensor())?,
                    &[],
                )?[0];
            } else {
                let fact = model.outlet_fact(outlet)?;
                let wire = patch.tap_model(model, outlet)?;
                if fact.shape[axis].is_one() {
                    region_inputs[ix] = wire;
                } else if fact.shape[axis] == out_dim {
                    region_inputs[ix] = patch.wire_node(
                        format!("{}.slice_{r_start}_{ix}", node.name),
                        Slice::new(axis, r_start.clone(), r_end.clone()),
                        &[wire],
                    )?[0];
                } else {
                    return Ok(None);
                }
            }
        }
        region_outs.push(
            patch.wire_node(
                format!("{}.region_{r_start}", node.name),
                node.op.clone(),
                &region_inputs,
            )?[0],
        );
    }
    let result =
        patch.wire_node(format!("{}.concat", node.name), TypedConcat::new(axis), &region_outs)?[0];
    patch.shunt_outside(model, node.id.into(), result)?;
    Ok(Some(patch))
}
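
/// Builds a 0/1 scalar of datum type `dt`, broadcast to `rank` axes of size one.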
fn uniform_const(dt: DatumType, rank: usize, is_true: bool) -> TractResult<Tensor> {
    tensor0(is_true as i64).cast_to_dt(dt)?.into_owned().broadcast_into_rank(rank)
}