tract-hir 0.23.0-dev.4

Tiny, no-nonsense, self contained, TensorFlow and ONNX inference
Documentation
use crate::internal::*;

#[derive(Debug, Clone, new, Default, Hash, PartialEq, Eq)]
pub struct Tile;

impl Expansion for Tile {
    fn name(&self) -> StaticName {
        "Tile".into()
    }

    fn rules<'r, 'p: 'r, 's: 'r>(
        &'s self,
        s: &mut Solver<'r>,
        inputs: &'p [TensorProxy],
        outputs: &'p [TensorProxy],
    ) -> InferenceResult {
        check_input_arity(inputs, 2)?;
        check_output_arity(outputs, 1)?;
        s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
        s.equals(&inputs[0].rank, &outputs[0].rank)?;
        s.equals(&inputs[1].rank, 1)?;
        s.equals(&inputs[1].shape[0], inputs[0].rank.bex().to_dim())?;
        s.given(&inputs[1].value, move |s, mult| {
            for (ix, m) in
                mult.cast_to::<TDim>()?.try_as_plain()?.as_slice::<TDim>()?.iter().enumerate()
            {
                if let Some(m) = m.as_i64() {
                    s.equals(m * inputs[0].shape[ix].bex(), &outputs[0].shape[ix])?;
                } else {
                    let m = m.clone();
                    s.given(&inputs[0].shape[ix], move |s, input| {
                        s.equals(input * &m, &outputs[0].shape[ix])
                    })?;
                }
            }
            Ok(())
        })?;
        Ok(())
    }

    fn wire(
        &self,
        prefix: &str,
        target: &mut TypedModel,
        inputs: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        if let Some(ref mult) = target.outlet_fact(inputs[1])?.konst {
            let mult: TVec<TDim> =
                mult.cast_to::<TDim>()?.try_as_plain()?.as_slice::<TDim>()?.into();
            target.wire_node(prefix, tract_core::ops::array::Tile::new(mult), &inputs[0..1])
        } else {
            bail!("shape input is variable")
        }
    }
}