tract-metal 0.22.2

Tiny, no-nonsense, self contained, TensorFlow and ONNX inference
use crate::kernels;
use crate::utils::with_borrowed_metal_stream;
use tract_core::internal::*;
use tract_core::ops::array::Slice;
use tract_gpu::tensor::DeviceTensorExt;

#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct MetalSlice(Slice);

impl MetalSlice {
    pub fn from_tract_core(op: Slice) -> Self {
        Self(op)
    }
}

impl Op for MetalSlice {
    fn name(&self) -> StaticName {
        "MetalSlice".into()
    }

    fn info(&self) -> TractResult<Vec<String>> {
        self.0.info()
    }

    op_as_typed_op!();

    fn same_as(&self, other: &dyn Op) -> bool {
        if let Some(other) = other.downcast_ref::<Self>() { other == self } else { false }
    }
}

impl EvalOp for MetalSlice {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval_with_session(
        &self,
        node_id: usize,
        session: &SessionState,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        let opaque = args_1!(inputs);
        let input = opaque.to_device_tensor()?;

        let start = self.0.start.eval(&session.resolved_symbols).to_usize()?;
        let end = self.0.end.eval(&session.resolved_symbols).to_usize()?;
        let axis = self.0.axis;

        let input_shape = input.shape();
        let input_strides = input.strides();
        let input_dt = input.datum_type();

        ensure!(
            end <= input_shape[axis] && start <= end,
            "Invalid range {}..{} for slicing {:?} on axis {}",
            start,
            end,
            input,
            axis
        );

        let mut o_shape: TVec<_> = input_shape.into();
        o_shape[axis] = end - start;

        let offset = (start * input_strides[axis] as usize) * input_dt.size_of();

        let output = tract_gpu::session_handler::make_tensor_for_node(
            session,
            node_id,
            input.datum_type(),
            &o_shape,
        )?;

        // Perform slicing only if the output is not empty.
        if o_shape[axis] != 0 {
            with_borrowed_metal_stream(|stream| {
                kernels::array::MultiBroadcast.dispatch_eval(stream, input, offset, &output)
            })?;
        }
        Ok(tvec![output.into_opaque_tensor().into_tvalue()])
    }
}

impl TypedOp for MetalSlice {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        tract_gpu::utils::facts_to_device_facts(inputs, |facts| self.0.output_facts(facts))
            .with_context(|| format!("Error while computing facts for {:?}", self.name()))
    }

    fn concretize_dims(
        &self,
        _source: &TypedModel,
        node: &TypedNode,
        target: &mut TypedModel,
        mapping: &HashMap<OutletId, OutletId>,
        values: &SymbolValues,
    ) -> TractResult<TVec<OutletId>> {
        let op = MetalSlice(Slice {
            axis: self.0.axis,
            start: self.0.start.eval(values),
            end: self.0.end.eval(values),
        });
        let inputs = node.inputs.iter().map(|i| mapping[i]).collect::<TVec<_>>();
        target.wire_node(&node.name, op, &inputs)
    }

    as_op!();
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::with_borrowed_metal_stream;
    use tract_core::internal::Tensor;
    use tract_gpu::tensor::IntoDevice;

    fn run_test(shape: &[usize], slice: Slice) -> TractResult<()> {
        with_borrowed_metal_stream(|stream| {
            let num_elements = shape.iter().product();

            let a = Tensor::from_shape(
                &shape,
                &(0..num_elements).map(|f| f as f32).collect::<Vec<_>>(),
            )?;

            let cpu_output = slice.eval_with_session(
                0,
                &SessionState::default(),
                tvec![a.clone().into_tvalue()],
            )?;

            let metal_slice = MetalSlice::from_tract_core(slice);
            let a_metal = a.clone().into_device()?.into_opaque_tensor().into_tvalue();
            let mut session_state = SessionState::default();
            let metal_output =
                metal_slice.eval_with_session(0, &mut session_state, tvec![a_metal])?;
            stream.wait_until_completed()?;

            cpu_output[0].close_enough(
                &metal_output[0].to_device_tensor()?.to_host()?.into_tensor(),
                Approximation::Approximate,
            )?;
            Ok(())
        })
    }

    #[test]
    fn test_slice() -> TractResult<()> {
        run_test(&[4, 4], Slice { axis: 1, start: 0.into(), end: 4.into() })?;
        run_test(&[8, 3, 5], Slice { axis: 1, start: 1.into(), end: 3.into() })?;
        assert!(run_test(&[8, 3, 5], Slice { axis: 1, start: 1.into(), end: 7.into() }).is_err());
        run_test(&[8, 3, 5], Slice { axis: 1, start: 1.into(), end: 1.into() })?;
        Ok(())
    }
}