//! tract-onnx 0.19.2
//!
//! Tiny, no-nonsense, self-contained TensorFlow and ONNX inference.
//! See the crate documentation for details.
use crate::model::ParsingContext;
use crate::pb::*;
use std::hash::Hash;
use tract_hir::internal::*;

pub fn resize(
    _ctx: &ParsingContext,
    node: &NodeProto,
) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
    let coord_transformer =
        match node.get_attr_opt("coordinate_transformation_mode")?.unwrap_or("half_pixel") {
            "align_corners" => CoordTransformer::AlignCorners,
            "half_pixel" => CoordTransformer::HalfPixel,
            "asymmetric" => CoordTransformer::Asymmetric,
            s => todo!("coordinate_transformation_mode: {}", s),
        };
    let interpolator = match node.get_attr_opt("mode")?.unwrap_or("nearest") {
        "nearest" => Interpolator::Nearest,
        "linear" => Interpolator::Linear,
        s => todo!("mode: {}", s),
    };
    let nearest = match node.get_attr_opt("nearest_mode")?.unwrap_or("round_prefer_floor") {
        "floor" => Nearest::Floor,
        "ceil" => Nearest::Ceil,
        "round_prefer_floor" => Nearest::RoundPreferFloor,
        "round_prefer_ceil" => Nearest::RoundPreferCeil,
        s => todo!("nearest_mode: {}", s),
    };
    let mut options = crate::model::optional_inputs(node).skip(2);
    Ok((
        Box::new(Resize {
            optional_scales_input: options.next().unwrap(),
            optional_sizes_input: options.next().unwrap(),
            coord_transformer,
            interpolator,
            nearest,
        }),
        vec![],
    ))
}

/// Maps an output-grid coordinate back to a (fractional) input-space
/// coordinate, per the ONNX `coordinate_transformation_mode` attribute.
#[derive(Clone, Debug, Hash)]
enum CoordTransformer {
    HalfPixel,
    AlignCorners,
    Asymmetric,
}

impl CoordTransformer {
    /// Translates output index `x_out` on an axis of length `len_out` into the
    /// corresponding input coordinate on the axis of length `len_in`, where
    /// `scale = len_out / len_in`.
    fn transform(&self, x_out: usize, scale: f32, len_in: usize, len_out: usize) -> f32 {
        let x = x_out as f32;
        match self {
            CoordTransformer::HalfPixel => (x + 0.5) / scale - 0.5,
            CoordTransformer::AlignCorners => x * (len_in as f32 - 1.0) / (len_out as f32 - 1.0),
            CoordTransformer::Asymmetric => x / scale,
        }
    }
}

/// Interpolation strategy between the two input neighbors of a coordinate,
/// per the ONNX `mode` attribute.
#[derive(Clone, Debug, Hash)]
enum Interpolator {
    Linear,
    Nearest,
}

impl Interpolator {
    /// Combines the two neighboring samples `y_left`/`y_right`, given the
    /// fractional position `x_ratio` in `[0, 1]` between them. For `Nearest`,
    /// `nearest_mode` decides ties and rounding direction.
    fn interpolate(&self, y_left: f32, y_right: f32, x_ratio: f32, nearest_mode: Nearest) -> f32 {
        match self {
            // Two-product form of linear interpolation (kept in this exact
            // shape so results stay bit-identical across f32 rounding).
            Interpolator::Linear => y_left * (1.0 - x_ratio) + y_right * x_ratio,
            Interpolator::Nearest => {
                let take_left = match nearest_mode {
                    Nearest::Floor => true,
                    Nearest::Ceil => false,
                    Nearest::RoundPreferFloor => x_ratio <= 0.5,
                    Nearest::RoundPreferCeil => x_ratio < 0.5,
                };
                if take_left {
                    y_left
                } else {
                    y_right
                }
            }
        }
    }
}

/// Rounding / tie-breaking rule used when `Interpolator::Nearest` is active,
/// per the ONNX `nearest_mode` attribute.
#[derive(Clone, Copy, Debug, Hash)]
enum Nearest {
    Floor,
    Ceil,
    RoundPreferFloor,
    RoundPreferCeil,
}

/// The ONNX `Resize` operator: rescales a tensor along each axis, driven by
/// either an explicit `sizes` input or per-axis `scales` factors.
#[derive(Clone, new, Debug, Hash)]
struct Resize {
    // NOTE: field order defines the argument order of the `new` constructor
    // generated by `derive(new)` — do not reorder without checking callers.
    coord_transformer: CoordTransformer,
    interpolator: Interpolator,
    nearest: Nearest,
    // Input slot index of the optional "scales" tensor, if the node has one.
    optional_scales_input: Option<usize>,
    // Input slot index of the optional "sizes" tensor, if the node has one.
    optional_sizes_input: Option<usize>,
}

// Bridges the derived `Hash` impl into tract's dynamic hashing machinery.
impl_dyn_hash!(Resize);

impl Op for Resize {
    // Human-readable operator name used in dumps and error messages.
    fn name(&self) -> Cow<str> {
        "Resize".into()
    }

    op_as_typed_op!();
}

impl Resize {
    /// Computes the output shape from either the `scales` or the `sizes`
    /// optional input.
    ///
    /// Preference order: if `input_scale` is present with one entry per input
    /// axis, each output dimension is `trunc(input_dim * scale)`; otherwise,
    /// if `input_sizes` has one entry per axis it is used verbatim. Errors
    /// when neither input is usable.
    fn compute_output_shape(
        &self,
        input_shape: &[usize],
        input_scale: Option<&Tensor>,
        input_sizes: Option<&Tensor>,
    ) -> TractResult<TVec<usize>> {
        if let Some(scale) = input_scale {
            if scale.len() == input_shape.len() {
                let scales = scale.cast_to::<f32>()?;
                return Ok(input_shape
                    .iter()
                    .zip(scales.as_slice::<f32>()?.iter())
                    .map(|(input, scale)| ((*input as f32) * scale) as usize)
                    .collect());
            }
        }
        if let Some(sizes) = input_sizes {
            if sizes.len() == input_shape.len() {
                let size = sizes.cast_to::<i64>()?;
                return Ok(size.as_slice::<i64>()?.iter().map(|i| *i as usize).collect());
            }
        }
        bail!(
            "Neither shape nor scale makes sense: input_shape: {:?}, scale: {:?}, sizes: {:?}",
            input_shape,
            input_scale,
            input_sizes,
        );
    }
}

impl EvalOp for Resize {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, mut inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        // Look up the optional scales/sizes tensors at the slots recorded at
        // parse time, then derive the target shape before consuming input 0.
        let scales = self.optional_scales_input.and_then(|ix| inputs.get(ix));
        let sizes = self.optional_sizes_input.and_then(|ix| inputs.get(ix));
        let output_shape = self.compute_output_shape(
            inputs[0].shape(),
            scales.map(|t| &**t),
            sizes.map(|t| &**t),
        )?;
        let mut data = inputs.remove(0).into_tensor().into_array::<f32>()?;
        // Resize one axis at a time, rebuilding the whole array at each step.
        for axis in 0..data.ndim() {
            // NOTE(review): only equal-size (skip) and upsampling branches
            // exist; an axis where output < input falls through and is left
            // unresized — presumably downscaling is unsupported here; verify.
            #[allow(clippy::comparison_chain)]
            if output_shape[axis] == data.shape()[axis] {
                continue;
            } else if output_shape[axis] > data.shape()[axis] {
                let scale = output_shape[axis] as f32 / data.shape()[axis] as f32;
                let mut new_shape: TVec<usize> = data.shape().into();
                new_shape[axis] = output_shape[axis];
                data = tract_ndarray::ArrayD::from_shape_fn(&*new_shape, |co_o| -> f32 {
                    // Map this output coordinate back to input space, then
                    // interpolate between its two input-axis neighbors.
                    let x_out = co_o[axis];
                    let x_in = self.coord_transformer.transform(
                        x_out,
                        scale,
                        data.shape()[axis],
                        new_shape[axis],
                    );
                    let mut co_i = co_o;
                    // `as usize` saturates a negative x_in (possible with
                    // HalfPixel) to 0; clamp bounds the upper end.
                    let x_left = (x_in as usize).clamp(0, data.shape()[axis] - 1);
                    co_i[axis] = x_left;
                    let y_left = data[&co_i];
                    // Right neighbor, pinned at the last valid index.
                    let x_right = (x_left + 1).min(data.shape()[axis] - 1);
                    co_i[axis] = x_right;
                    let y_right = data[&co_i];
                    // Fractional position of x_in between the two neighbors.
                    let x_frac = x_in - x_left as f32;
                    self.interpolator.interpolate(y_left, y_right, x_frac, self.nearest)
                })
            }
        }
        Ok(tvec!(data.into_tvalue()))
    }
}

impl InferenceRulesOp for Resize {
    fn rules<'r, 'p: 'r, 's: 'r>(
        &'s self,
        s: &mut Solver<'r>,
        inputs: &'p [TensorProxy],
        outputs: &'p [TensorProxy],
    ) -> InferenceResult {
        check_output_arity(outputs, 1)?;
        // The output keeps the input's element type and rank; only the
        // dimensions change.
        s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
        s.equals(&inputs[0].rank, &outputs[0].rank)?;
        // With exactly 3 inputs, slot 2 is unambiguous: it is whichever
        // optional input (scales or sizes) was recorded there at parse time.
        if inputs.len() == 3 && self.optional_scales_input == Some(2) {
            rules_with_scales(self, s, inputs, outputs)
        } else if inputs.len() == 3 && self.optional_sizes_input == Some(2) {
            rules_with_sizes(self, s, inputs, outputs)
        } else {
            // Bogus 4-input case: both scales and sizes slots are wired.
            // Decide at solve time: scales is only valid when it carries one
            // entry per input axis; otherwise fall back to sizes.
            s.given_2(
                &inputs[0].rank,
                &inputs[self.optional_scales_input.unwrap()].shape,
                move |s, input_rank, scale_shape| {
                    if scale_shape.len() == 0 || scale_shape[0] != input_rank.to_dim() {
                        rules_with_sizes(self, s, inputs, outputs)
                    } else {
                        rules_with_scales(self, s, inputs, outputs)
                    }
                },
            )
        }
    }

    as_op!();
    to_typed!();
}

/// Solver rules for the Resize variant where the `scales` input drives the
/// output geometry: each output dim is the input dim times its scale factor.
fn rules_with_scales<'r, 'p: 'r, 's: 'r>(
    op: &'s Resize,
    s: &mut Solver<'r>,
    inputs: &'p [TensorProxy],
    outputs: &'p [TensorProxy],
) -> InferenceResult {
    let scales_ix = op.optional_scales_input.unwrap();
    let scales = &inputs[scales_ix];
    // scales must be a rank-1 f32 tensor with one entry per input axis.
    s.equals(&scales.datum_type, f32::datum_type())?;
    s.equals(&scales.rank, 1)?;
    s.equals(&scales.shape[0], inputs[0].rank.bex().to_dim())?;
    // Once the input shape and the scale values are both known, pin every
    // output dimension.
    s.given_2(&inputs[0].shape, &inputs[scales_ix].value, move |s, input_shape, scales| {
        let input_shape =
            input_shape.iter().map(|d| d.to_usize()).collect::<TractResult<TVec<usize>>>()?;
        let output_size = op.compute_output_shape(&input_shape, Some(scales.as_ref()), None)?;
        for axis in 0..output_size.len() {
            s.equals(&outputs[0].shape[axis], output_size[axis].to_dim())?;
        }
        Ok(())
    })
}

/// Solver rules for the Resize variant where the `sizes` input drives the
/// output geometry: output dims are taken verbatim from the sizes tensor.
fn rules_with_sizes<'r, 'p: 'r, 's: 'r>(
    op: &'s Resize,
    s: &mut Solver<'r>,
    inputs: &'p [TensorProxy],
    outputs: &'p [TensorProxy],
) -> InferenceResult {
    let sizes = &inputs[op.optional_sizes_input.unwrap()];
    // sizes must be a rank-1 tensor with one entry per input axis.
    s.equals(&sizes.rank, 1)?;
    s.equals(&sizes.shape[0], inputs[0].rank.bex().to_dim())?;
    // Once the rank is known, copy each sizes element into the output shape.
    s.given(&inputs[0].rank, move |s, rank| {
        for axis in 0..(rank as usize) {
            s.equals(&outputs[0].shape[axis], sizes.value[axis].bex().to_dim())?;
        }
        Ok(())
    })
}

impl TypedOp for Resize {
    as_op!();

    /// Derives the single output fact: same datum type as the input, with the
    /// shape computed from the constant scales/sizes inputs.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let input_shape = match inputs[0].shape.as_concrete() {
            Some(shape) => shape,
            None => bail!("Only constant input shape are supported in Resize"),
        };
        // Constant values (if any) of the optional scales/sizes inputs.
        let scales_konst = self
            .optional_scales_input
            .and_then(|ix| inputs.get(ix))
            .and_then(|fact| fact.konst.as_deref());
        let sizes_konst = self
            .optional_sizes_input
            .and_then(|ix| inputs.get(ix))
            .and_then(|fact| fact.konst.as_deref());
        let output_shape = self.compute_output_shape(input_shape, scales_konst, sizes_konst)?;
        Ok(tvec!(inputs[0].datum_type.fact(&output_shape)))
    }

    /// No simplification implemented for this operator.
    fn declutter(
        &self,
        _model: &TypedModel,
        _node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        Ok(None)
    }
}