tract-metal 0.22.2

Tiny, no-nonsense, self-contained TensorFlow and ONNX inference
use crate::kernels::nn::Reducer;
use crate::utils::with_borrowed_metal_stream;
use tract_core::internal::*;
use tract_core::ops::nn as core_ops_nn;
use tract_gpu::tensor::DeviceTensorExt;
use tract_itertools::Itertools;

/// Reduce operator dispatched through the Metal kernels.
///
/// Both fields are public, so a value can be built without going through
/// [`MetalReduce::new`]; only `new` checks that exactly one axis is given.
#[derive(Clone, Debug, Hash)]
pub struct MetalReduce {
    /// Axes to reduce. Evaluation only reads `axes[0]`.
    pub axes: TVec<usize>,
    /// Reduction kernel (from `crate::kernels::nn`) applied along the axis.
    pub reducer: Reducer,
}

impl MetalReduce {
    pub fn new(axes: TVec<usize>, reducer: Reducer) -> TractResult<Self> {
        ensure!(axes.len() == 1, "Only one axis of reduce is supported by MetalReduce");
        Ok(Self { axes, reducer })
    }

    pub fn from_tract_core(core_reduce: &core_ops_nn::Reduce) -> TractResult<Self> {
        let metal_reducer = match core_reduce.reducer {
            core_ops_nn::Reducer::Sum => Reducer::Sum,
            core_ops_nn::Reducer::MeanOfSquares => Reducer::MeanOfSquares,
            core_ops_nn::Reducer::Prod => Reducer::Prod,
            core_ops_nn::Reducer::Min => Reducer::Min,
            core_ops_nn::Reducer::Max => Reducer::Max,
            _ => bail!("Unsupported reducer {:?} on metal", core_reduce.reducer),
        };
        Self::new(core_reduce.axes.clone(), metal_reducer)
    }
}

impl Op for MetalReduce {
    /// Operator name, including the debug form of the reducer.
    fn name(&self) -> StaticName {
        let label = format!("MetalReduce<{:?}>", self.reducer);
        label.into()
    }

    /// One informational line listing the reduced axes.
    fn info(&self) -> TractResult<Vec<String>> {
        let axes_line = format!("axes: {:?}", self.axes);
        Ok(vec![axes_line])
    }

    op_as_typed_op!();
}

impl EvalOp for MetalReduce {
    /// The operator reports itself as stateless.
    fn is_stateless(&self) -> bool {
        true
    }

    /// Reduces the single input along `self.axes[0]`.
    ///
    /// The output keeps the input's datum type and shape, except that the
    /// reduced axis has length 1. The output tensor is obtained through
    /// `tract_gpu::session_handler::make_tensor_for_node`, and the reduction
    /// itself is performed by `Reducer::dispatch_eval`.
    fn eval_with_session(
        &self,
        node_id: usize,
        session: &SessionState,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        with_borrowed_metal_stream(|stream| {
            let source = args_1!(inputs);
            let device_input = source.to_device_tensor()?;

            // Only the first axis is ever reduced here.
            let axis = self.axes[0];
            let mut reduced_shape = device_input.shape().to_vec();
            reduced_shape[axis] = 1;

            let result = tract_gpu::session_handler::make_tensor_for_node(
                session,
                node_id,
                device_input.datum_type(),
                &reduced_shape,
            )?;

            self.reducer.dispatch_eval(stream, device_input, axis, &result)?;

            let value = result.into_opaque_tensor().into_tvalue();
            Ok(tvec!(value))
        })
    }
}

impl TypedOp for MetalReduce {
    /// Computes the output fact: same datum type as the input, same shape
    /// except that every axis in `self.axes` is set to 1.
    ///
    /// The shape computation is handed as a closure to
    /// `tract_gpu::utils::facts_to_device_facts`, which supplies the facts the
    /// closure works on.
    ///
    /// # Errors
    ///
    /// Fails when `self.axes` is not strictly increasing, when an axis is not
    /// smaller than the input rank, or when `facts_to_device_facts` fails.
    /// Errors raised while computing the facts carry the operator name as
    /// context.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(self.axes.iter().tuple_windows().all(|(a, b)| a < b));
        tract_gpu::utils::facts_to_device_facts(inputs, |facts| {
            let mut shape: TVec<_> = facts[0].shape.to_tvec();
            let rank = shape.len();
            for &ax in &self.axes {
                // The fields are public, so the axes may not have been
                // validated; report a bad axis instead of panicking on the
                // index below.
                ensure!(
                    ax < rank,
                    "Reduce axis {} is out of range for an input of rank {}",
                    ax,
                    rank
                );
                shape[ax] = 1.to_dim();
            }
            let dt = facts[0].datum_type;
            Ok(tvec!(dt.fact(shape)))
        })
        .with_context(|| format!("Error while computing facts for {:?}", self.name()))
    }

    as_op!();
}