use super::Downsample;
use crate::internal::*;
use crate::ops;
pub fn pull_downsample_over_slice(
model: &TypedModel,
slice_node: &TypedNode,
slice_op: &ops::array::Slice,
down_node: &TypedNode,
down_op: &Downsample,
) -> TractResult<Option<TypedModelPatch>> {
if down_op.axis != slice_op.axis {
return Ok(None);
}
if down_op.stride < 0 {
return Ok(None);
}
let modulo = (down_op.modulo + slice_op.start.to_usize()?) % down_op.stride as usize;
let left = (down_op.modulo + slice_op.start.to_usize()?) / down_op.stride as usize;
let mut patch = TypedModelPatch::default();
let tap = patch.tap_model(model, slice_node.inputs[0])?;
let final_len = down_node.outputs[0].fact.shape[down_op.axis].clone();
let new_down = Downsample::new(down_op.axis, down_op.stride, modulo);
let ds = patch.wire_node(&*down_node.name, new_down, [tap].as_ref())?;
let new_start = left;
let new_end = final_len + left;
let op = ops::array::Slice::new(slice_op.axis, new_start.to_dim(), new_end);
let new_slice = patch.wire_node(&*slice_node.name, op, &ds)?[0];
patch.shunt_outside(model, OutletId::new(down_node.id, 0), new_slice)?;
Ok(Some(patch))
}
pub fn pull_downsample_over_axis_op(
model: &TypedModel,
axis_node: &TypedNode,
axis_op: &AxisOp,
down_node: &TypedNode,
down_op: &Downsample,
) -> TractResult<Option<TypedModelPatch>> {
let mut patch = TypedModelPatch::default();
let tap = patch.tap_model(model, axis_node.inputs[0])?;
let mut new_down = down_op.clone();
new_down.axis =
axis_op.recip().transform_axis(down_op.axis).ok_or_else(|| format_err!("Invalid axis"))?;
let wire = patch.wire_node(&*down_node.name, new_down, [tap].as_ref())?;
let wire = patch.wire_node(&*axis_node.name, axis_op.clone(), &wire)?[0];
patch.shunt_outside(model, OutletId::new(down_node.id, 0), wire)?;
Ok(Some(patch))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ops;
use proptest::prelude::*;
use proptest::test_runner::TestCaseResult;
fn crop_then_down_strat() -> BoxedStrategy<(usize, usize, usize, usize, usize)> {
(1usize..5, 1usize..5)
.prop_flat_map(|(cropped, stride)| {
(Just(cropped), 0..=cropped, Just(stride), (cropped + 15)..=(cropped + 15))
})
.prop_flat_map(|(cropped, left, stride, len)| {
(Just(len), Just(left), Just(cropped - left), Just(stride), 0..stride)
})
.boxed()
}
fn crop_then_down(
len: usize,
left: usize,
right: usize,
stride: usize,
modulo: usize,
) -> TestCaseResult {
let _ = env_logger::Builder::from_env("TRACT_LOG").try_init();
let mut model = {
let mut model = TypedModel::default();
let input = model.add_source("input", i32::fact([len])).unwrap();
let crop = model
.wire_node(
"crop",
ops::array::Slice::new(0, left.to_dim(), (len - right).to_dim()),
&[input],
)
.unwrap();
let down = model
.wire_node("down", Downsample::new(0, stride as isize, modulo), &crop)
.unwrap();
model.set_output_outlets(&down).unwrap();
model
};
trace!("{:#?}", model);
prop_assert!(model.node(model.output_outlets().unwrap()[0].node).op_is::<Downsample>());
let input = tensor1(&(0i32..len as _).collect::<Vec<_>>());
let expected = SimplePlan::new(&model).unwrap().run(tvec!(input.clone().into())).unwrap();
info!("Decluttering");
model.declutter().unwrap();
trace!("{:#?}", model);
let order = model.eval_order().unwrap();
prop_assert!(
model.node(order[1]).op_is::<Downsample>()
|| !model.nodes().iter().any(|n| n.op_is::<Downsample>())
);
let found = SimplePlan::new(&model).unwrap().run(tvec!(input.into())).unwrap();
prop_assert_eq!(found, expected);
Ok(())
}
proptest! {
#[test]
fn crop_then_down_prop((len, left, right, stride, modulo) in crop_then_down_strat()) {
crop_then_down(len, left, right, stride, modulo).unwrap()
}
}
#[test]
fn crop_then_down_1() {
crop_then_down(1, 0, 0, 2, 0).unwrap()
}
#[test]
fn crop_then_down_2() {
crop_then_down(2, 0, 1, 2, 0).unwrap()
}
#[test]
fn crop_then_down_3() {
crop_then_down(0, 0, 0, 2, 1).unwrap()
}
#[test]
fn crop_then_down_4() {
crop_then_down(1, 0, 1, 2, 1).unwrap()
}
#[test]
fn crop_then_down_5() {
crop_then_down(16, 0, 1, 2, 1).unwrap()
}
}