use tract_itertools::Itertools;
use crate::internal::*;
use crate::ops::array::Slice;
use crate::optim::OptimizerSession;
/// Typed optimizer pass that pushes `Slice` operators upstream, past the node producing
/// the tensor they slice, so that the producer only computes the range actually consumed.
///
/// The pass carries no state of its own.
#[derive(Debug, Clone)]
pub struct PushSliceUp;
impl super::TypedPass for PushSliceUp {
    fn reset(&mut self) -> TractResult<()> {
        // Nothing to clear between runs.
        Ok(())
    }

    /// Returns a patch for the first opportunity found (in evaluation order) to push a
    /// `Slice` above the node it consumes, or `Ok(None)` when no such opportunity exists.
    ///
    /// Only single-output nodes are considered, axis by axis:
    /// * when `model.single_succ` yields a successor that is a `Slice` on the current axis,
    ///   either the two slices are fused (the node is itself a `Slice` on the same axis), or
    ///   the node is recomputed on the sliced range only and the successor is bypassed;
    /// * otherwise, when `should_slice_output` proposes cut points, the node is split into
    ///   one piece per segment and its consumers are rewired onto those pieces.
    fn next(
        &mut self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
    ) -> TractResult<Option<TypedModelPatch>> {
        let order = model.eval_order()?;
        for &node_id in &order {
            let node = model.node(node_id);
            if node.outputs.len() != 1 {
                continue;
            }
            let out_fact = &node.outputs[0].fact;
            for axis in 0..out_fact.rank() {
                match model.single_succ(node_id)? {
                    Some(succ) => {
                        // Only a Slice successor acting on the axis under scrutiny matters.
                        let Some(slice) = succ.op_as::<Slice>().filter(|s| s.axis == axis)
                        else {
                            continue;
                        };
                        if let Some(me) = node.op_as::<Slice>() {
                            if me.axis == axis {
                                // The downstream range is relative to `me`'s output: shift
                                // it by `me.start` and fuse both slices into one.
                                let start = me.start.clone() + &slice.start;
                                let end = start.clone() + (slice.end.clone() - &slice.start);
                                return TypedModelPatch::fuse_with_next(
                                    model,
                                    node,
                                    Slice { axis, start, end },
                                )
                                .map(Some);
                            }
                            // Slices on different axes: this strict comparison cannot hold in
                            // both orders, which presumably keeps two such slices from being
                            // swapped back and forth indefinitely.
                            let upstream_len = &out_fact.shape[me.axis];
                            let downstream_len = &succ.outputs[0].fact.shape[axis];
                            if !(upstream_len.clone() - downstream_len).prove_strict_positive() {
                                continue;
                            }
                        }
                        let cuts = tvec!(
                            0.to_dim(),
                            slice.start.clone(),
                            slice.end.clone(),
                            out_fact.shape[axis].clone()
                        );
                        let Some((mut patch, pieces)) =
                            op_slices_to_slice_op(model, node, axis, &cuts)?
                        else {
                            continue;
                        };
                        ensure!(pieces.len() == 3);
                        // pieces[1] is the node recomputed on exactly the slice's range.
                        patch.shunt_outside(model, succ.id.into(), pieces[1])?;
                        return Ok(Some(patch));
                    }
                    None => {
                        let Some(cuts) = should_slice_output(model, node, axis, &order)? else {
                            continue;
                        };
                        let cuts_dim: TVec<TDim> = cuts.iter().map(|c| c.to_dim()).collect();
                        let Some((mut patch, pieces)) =
                            op_slices_to_slice_op(model, node, axis, &cuts_dim)?
                        else {
                            continue;
                        };
                        ensure!(pieces.len() == cuts.len() - 1);
                        rewire_sliced_outputs(model, node, axis, &mut patch, &cuts, &pieces)
                            .context("Rewiring sliced outputs")?;
                        return Ok(Some(patch));
                    }
                }
            }
        }
        Ok(None)
    }
}
/// Re-expresses `node` as one copy per segment of `boundaries` along output `axis`.
///
/// `boundaries` must start at `0` and end at the length of `node`'s output on `axis`; each
/// consecutive pair `(start, end)` defines one segment. For every segment:
/// * each input tied to output `axis` through exactly one input axis is wired through a
///   `Slice` on that axis, unless that axis has length 1, in which case it is passed as is;
/// * the op's own `slice` hook is then asked to build the computation for the segment.
///
/// Returns the patch together with one outlet per segment, in order. Returns `Ok(None)` if
/// the op declines to be sliced for any segment.
///
/// # Errors
/// Fails if `boundaries` is empty or does not span `0..len`, or if output `axis` is absent
/// from the op's axes mapping. Also fails if building the patch fails.
fn op_slices_to_slice_op(
    model: &TypedModel,
    node: &TypedNode,
    axis: usize,
    boundaries: &[TDim],
) -> TractResult<Option<(TypedModelPatch, TVec<OutletId>)>> {
    let (ifacts, ofacts) = model.node_facts(node.id)?;
    let invariants = node
        .op
        .axes_mapping(&ifacts, &ofacts)
        .with_context(|| format!("Mapping axes for {node}"))?;
    let mut splits = tvec!();
    let mut patch = TypedModelPatch::new(format!("Slice {node} by {boundaries:?}"));
    let inputs = patch.taps(model, &node.inputs)?;
    let len = &node.outputs[0].fact.shape[axis];
    // `first()`/`last()` instead of indexing: an empty `boundaries` is an error, not a panic.
    ensure!(
        boundaries.first() == Some(&0.to_dim()),
        "Slice boundaries for {node} must start at 0, got {boundaries:?}"
    );
    ensure!(
        boundaries.last() == Some(len),
        "Slice boundaries for {node} must end at {len}, got {boundaries:?}"
    );
    let axis_info = invariants
        .axis((InOut::Out(0), axis))
        .with_context(|| format!("Output axis {axis} of {node} not found in axes mapping"))?;
    for (start, end) in boundaries.iter().tuple_windows() {
        let mut wires = tvec!();
        for (input_ix, &input) in inputs.iter().enumerate() {
            let mut wire = input;
            // Only inputs carrying the sliced axis exactly once are sliced; a length-1 axis
            // is left untouched.
            if let &[input_axis] = &*axis_info.inputs[input_ix] {
                if !patch.outlet_fact(wire)?.shape[input_axis].is_one() {
                    wire = patch.wire_node(
                        format!(
                            "{}.split-{}-over-{}.{}..{}.slice",
                            &node.name, input_ix, input_axis, start, end
                        ),
                        Slice { axis: input_axis, start: start.clone(), end: end.clone() },
                        &[wire],
                    )?[0];
                }
            }
            wires.push(wire);
        }
        rule_if_some!(
            wire = node
                .op
                .slice(
                    &mut patch,
                    model,
                    node,
                    &format!("{}.split-over-{}.{}..{}", &node.name, axis, start, end),
                    &wires,
                    axis,
                    start,
                    end,
                )
                .with_context(|| format!("Calling slice on {node}"))?
        );
        splits.push(wire[0]);
    }
    Ok(Some((patch, splits)))
}
fn should_slice_output(
model: &TypedModel,
node: &TypedNode,
axis: usize,
eval_order: &[usize],
) -> TractResult<Option<TVec<usize>>> {
rule_if!(node.outputs[0].successors.len() > 0);
rule_if!(!node.op_is::<Slice>());
let slicers: TVec<usize> = node.outputs[0]
.successors
.iter()
.filter(|inlet| {
model.node(inlet.node).op_as::<Slice>().filter(|slice| slice.axis == axis).is_some()
})
.map(|inlet| inlet.node)
.collect();
rule_if!(slicers.len() >= node.outputs[0].successors.len());
let slice = node.outputs[0].successors[0].node;
rule_if!(eval_order.contains(&slice));
let slice_op = model.node(slice).op_as::<Slice>().unwrap();
let axis = slice_op.axis;
let mut boundaries = tvec!();
for succ in &node.outputs[0].successors {
if let Some(slice) = model.node(succ.node).op_as::<Slice>()
&& slice.axis == axis
{
boundaries.push(slice.start.clone());
boundaries.push(slice.end.clone());
}
}
rule_if_let!(Ok(mut boundaries) =
boundaries.iter().map(|x| x.to_usize()).collect::<TractResult<TVec<usize>>>());
rule_if_let!(Ok(end) = node.outputs[0].fact.shape[axis].to_usize());
boundaries.push(end);
boundaries.sort();
boundaries.dedup();
rule_if!(boundaries.len() != 2);
Ok(Some(boundaries))
}
/// Rewires the consumers of `node` onto per-segment outlets.
///
/// `splits[i]` must compute `node`'s output restricted to `boundaries[i]..boundaries[i + 1]`
/// along `axis`.
///
/// The full output is rebuilt as a `TypedConcat` of all `splits` and substituted for
/// `node`'s output. Then every successor that is a `Slice` on `axis` is redirected to the
/// splits whose upper bound falls in `(slice.start, slice.end]`:
/// * several matching splits are concatenated;
/// * a single match is used directly;
/// * no match becomes a zero-length broadcast of a zero scalar.
///
/// Slice bounds are expected to coincide with entries of `boundaries`.
///
/// # Errors
/// Fails if a `Slice` successor on `axis` has a non-concrete start or end, or if wiring the
/// patch fails.
pub fn rewire_sliced_outputs(
    model: &TypedModel,
    node: &TypedNode,
    axis: usize,
    patch: &mut TypedModelPatch,
    boundaries: &[usize],
    splits: &[OutletId],
) -> TractResult<()> {
    let full = patch.wire_node(
        format!("{}.concat-{}", node.name, axis),
        crate::ops::array::TypedConcat::new(axis),
        splits,
    )?[0];
    patch.shunt_outside(model, node.id.into(), full)?;
    let zero = patch.add_const(
        format!("{}.zero", node.name),
        Tensor::zero_scalar_dt(node.outputs[0].fact.datum_type)?,
    )?;
    for (ix, succ) in node.outputs[0].successors.iter().enumerate() {
        let Some(slice) =
            model.node(succ.node).op_as::<Slice>().filter(|slice| slice.axis == axis)
        else {
            continue;
        };
        // Resolve the bounds once per successor; symbolic bounds are an error, not a panic.
        let start = slice.start.to_usize()?;
        let end = slice.end.to_usize()?;
        let slices: TVec<OutletId> = boundaries
            .iter()
            .tuple_windows()
            .zip(splits.iter())
            .filter_map(|((_down, up), split)| (*up > start && *up <= end).then_some(*split))
            .collect();
        let wire = match &*slices {
            [] => {
                let mut empty_shape = node.outputs[0].fact.shape.clone();
                empty_shape.set(axis, 0.to_dim());
                patch.wire_node(
                    format!("{}.concat-m{}..{}..{}", node.name, ix, slice.start, slice.end),
                    crate::ops::array::MultiBroadcastTo::new(empty_shape),
                    &[zero],
                )?[0]
            }
            [single] => *single,
            _ => patch.wire_node(
                format!("{}.concat-m{}..{}..{}", node.name, ix, slice.start, slice.end),
                crate::ops::array::TypedConcat::new(axis),
                &slices,
            )?[0],
        };
        patch.shunt_outside(model, succ.node.into(), wire)?;
    }
    Ok(())
}